فهرست منبع

Updating config to be part of each provider object. Locking occurs exclusively on load and update. Once the config is loaded, we retain the instance unless updates occur.

Matt Bolt 6 سال پیش
والد
کامیت
05d9f77a5d
6فایلهای تغییر یافته به همراه410 افزوده شده و 417 حذف شده
  1. 64 86
      cloud/awsprovider.go
  2. 34 57
      cloud/azureprovider.go
  3. 32 46
      cloud/customprovider.go
  4. 60 78
      cloud/gcpprovider.go
  5. 4 150
      cloud/provider.go
  6. 216 0
      cloud/providerconfig.go

+ 64 - 86
cloud/awsprovider.go

@@ -61,6 +61,7 @@ type AWS struct {
 	ProjectID               string
 	DownloadPricingDataLock sync.RWMutex
 	ReservedInstances       []*AWSReservedInstance
+	Config                  *ProviderConfig
 	*CustomProvider
 }
 
@@ -256,7 +257,7 @@ func (aws *AWS) GetManagementPlatform() (string, error) {
 }
 
 func (aws *AWS) GetConfig() (*CustomPricing, error) {
-	c, err := GetCustomPricingData("aws.json")
+	c, err := aws.Config.GetCustomPricingData()
 	if c.Discount == "" {
 		c.Discount = "0%"
 	}
@@ -269,98 +270,75 @@ func (aws *AWS) GetConfig() (*CustomPricing, error) {
 	return c, nil
 }
 func (aws *AWS) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("aws.json")
-	if err != nil {
-		return nil, err
-	}
-
-	return configmapUpdate(c, configPathFor("aws.json"), a)
+	return aws.Config.UpdateFromMap(a)
 }
 
 func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("aws.json")
-	if err != nil {
-		return nil, err
-	}
-	if updateType == SpotInfoUpdateType {
-		a := AwsSpotFeedInfo{}
-		err := json.NewDecoder(r).Decode(&a)
-		if err != nil {
-			return nil, err
-		}
+	return aws.Config.Update(func(c *CustomPricing) error {
+		if updateType == SpotInfoUpdateType {
+			a := AwsSpotFeedInfo{}
+			err := json.NewDecoder(r).Decode(&a)
+			if err != nil {
+				return err
+			}
 
-		if err != nil {
-			return nil, err
-		}
-		c.ServiceKeyName = a.ServiceKeyName
-		c.ServiceKeySecret = a.ServiceKeySecret
-		c.SpotDataPrefix = a.Prefix
-		c.SpotDataBucket = a.BucketName
-		c.ProjectID = a.AccountID
-		c.SpotDataRegion = a.Region
-		c.SpotLabel = a.SpotLabel
-		c.SpotLabelValue = a.SpotLabelValue
-
-	} else if updateType == AthenaInfoUpdateType {
-		a := AwsAthenaInfo{}
-		err := json.NewDecoder(r).Decode(&a)
-		if err != nil {
-			return nil, err
-		}
-		c.AthenaBucketName = a.AthenaBucketName
-		c.AthenaRegion = a.AthenaRegion
-		c.AthenaDatabase = a.AthenaDatabase
-		c.AthenaTable = a.AthenaTable
-		c.ServiceKeyName = a.ServiceKeyName
-		c.ServiceKeySecret = a.ServiceKeySecret
-		c.ProjectID = a.AccountID
-	} else {
-		a := make(map[string]interface{})
-		err = json.NewDecoder(r).Decode(&a)
-		if err != nil {
-			return nil, err
-		}
-		for k, v := range a {
-			kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
-			vstr, ok := v.(string)
-			if ok {
-				err := SetCustomPricingField(c, kUpper, vstr)
-				if err != nil {
-					return nil, err
-				}
-			} else {
-				sci := v.(map[string]interface{})
-				sc := make(map[string]string)
-				for k, val := range sci {
-					sc[k] = val.(string)
+			c.ServiceKeyName = a.ServiceKeyName
+			c.ServiceKeySecret = a.ServiceKeySecret
+			c.SpotDataPrefix = a.Prefix
+			c.SpotDataBucket = a.BucketName
+			c.ProjectID = a.AccountID
+			c.SpotDataRegion = a.Region
+			c.SpotLabel = a.SpotLabel
+			c.SpotLabelValue = a.SpotLabelValue
+
+		} else if updateType == AthenaInfoUpdateType {
+			a := AwsAthenaInfo{}
+			err := json.NewDecoder(r).Decode(&a)
+			if err != nil {
+				return err
+			}
+			c.AthenaBucketName = a.AthenaBucketName
+			c.AthenaRegion = a.AthenaRegion
+			c.AthenaDatabase = a.AthenaDatabase
+			c.AthenaTable = a.AthenaTable
+			c.ServiceKeyName = a.ServiceKeyName
+			c.ServiceKeySecret = a.ServiceKeySecret
+			c.ProjectID = a.AccountID
+		} else {
+			a := make(map[string]interface{})
+			err := json.NewDecoder(r).Decode(&a)
+			if err != nil {
+				return err
+			}
+			for k, v := range a {
+				kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
+				vstr, ok := v.(string)
+				if ok {
+					err := SetCustomPricingField(c, kUpper, vstr)
+					if err != nil {
+						return err
+					}
+				} else {
+					sci := v.(map[string]interface{})
+					sc := make(map[string]string)
+					for k, val := range sci {
+						sc[k] = val.(string)
+					}
+					c.SharedCosts = sc //todo: support reflection/multiple map fields
 				}
-				c.SharedCosts = sc //todo: support reflection/multiple map fields
 			}
 		}
-	}
-	cj, err := json.Marshal(c)
-	if err != nil {
-		return nil, err
-	}
-
-	path := configPathFor("aws.json")
 
-	remoteEnabled := os.Getenv(remoteEnabled)
-	if remoteEnabled == "true" {
-		err = UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
-		if err != nil {
-			return nil, err
+		remoteEnabled := os.Getenv(remoteEnabled)
+		if remoteEnabled == "true" {
+			err := UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
+			if err != nil {
+				return err
+			}
 		}
-	}
-
-	configLock.Lock()
-	err = ioutil.WriteFile(path, cj, 0644)
-	configLock.Unlock()
 
-	if err != nil {
-		return nil, err
-	}
-	return c, nil
+		return nil
+	})
 }
 
 type awsKey struct {
@@ -477,7 +455,7 @@ func (aws *AWS) isPreemptible(key string) bool {
 func (aws *AWS) DownloadPricingData() error {
 	aws.DownloadPricingDataLock.Lock()
 	defer aws.DownloadPricingDataLock.Unlock()
-	c, err := GetCustomPricingData("aws.json")
+	c, err := aws.Config.GetCustomPricingData()
 	if err != nil {
 		klog.V(1).Infof("Error downloading default pricing data: %s", err.Error())
 	}
@@ -700,8 +678,8 @@ func (aws *AWS) DownloadPricingData() error {
 }
 
 // Stubbed NetworkPricing for AWS. Pull directly from aws.json for now
-func (c *AWS) NetworkPricing() (*Network, error) {
-	cpricing, err := GetCustomPricingData("aws.json")
+func (aws *AWS) NetworkPricing() (*Network, error) {
+	cpricing, err := aws.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}

+ 34 - 57
cloud/azureprovider.go

@@ -5,7 +5,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net/url"
 	"os"
 	"regexp"
@@ -166,6 +165,7 @@ type Azure struct {
 	allPrices               map[string]*Node
 	DownloadPricingDataLock sync.RWMutex
 	Clientset               clustercache.ClusterCache
+	Config                  *ProviderConfig
 }
 
 type azureKey struct {
@@ -462,8 +462,8 @@ func (az *Azure) NodePricing(key Key) (*Node, error) {
 }
 
 // Stubbed NetworkPricing for Azure. Pull directly from azure.json for now
-func (c *Azure) NetworkPricing() (*Network, error) {
-	cpricing, err := GetCustomPricingData("azure.json")
+func (az *Azure) NetworkPricing() (*Network, error) {
+	cpricing, err := az.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}
@@ -547,72 +547,49 @@ func (az *Azure) AddServiceKey(url url.Values) error {
 }
 
 func (az *Azure) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("azure.json")
-	if err != nil {
-		return nil, err
-	}
-
-	return configmapUpdate(c, configPathFor("azure.json"), a)
+	return az.Config.UpdateFromMap(a)
 }
 
 func (az *Azure) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
 	defer az.DownloadPricingData()
 
-	c, err := GetCustomPricingData("azure.json")
-	if err != nil {
-		return nil, err
-	}
-
-	a := make(map[string]interface{})
-	err = json.NewDecoder(r).Decode(&a)
-	if err != nil {
-		return nil, err
-	}
-	for k, v := range a {
-		kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
-		vstr, ok := v.(string)
-		if ok {
-			err := SetCustomPricingField(c, kUpper, vstr)
-			if err != nil {
-				return nil, err
-			}
-		} else {
-			sci := v.(map[string]interface{})
-			sc := make(map[string]string)
-			for k, val := range sci {
-				sc[k] = val.(string)
+	return az.Config.Update(func(c *CustomPricing) error {
+		a := make(map[string]interface{})
+		err := json.NewDecoder(r).Decode(&a)
+		if err != nil {
+			return err
+		}
+		for k, v := range a {
+			kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
+			vstr, ok := v.(string)
+			if ok {
+				err := SetCustomPricingField(c, kUpper, vstr)
+				if err != nil {
+					return err
+				}
+			} else {
+				sci := v.(map[string]interface{})
+				sc := make(map[string]string)
+				for k, val := range sci {
+					sc[k] = val.(string)
+				}
+				c.SharedCosts = sc //todo: support reflection/multiple map fields
 			}
-			c.SharedCosts = sc //todo: support reflection/multiple map fields
 		}
-	}
-
-	cj, err := json.Marshal(c)
-	if err != nil {
-		return nil, err
-	}
 
-	remoteEnabled := os.Getenv(remoteEnabled)
-	if remoteEnabled == "true" {
-		err = UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
-		if err != nil {
-			return nil, err
+		remoteEnabled := os.Getenv(remoteEnabled)
+		if remoteEnabled == "true" {
+			err := UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
+			if err != nil {
+				return err
+			}
 		}
-	}
-
-	configPath := configPathFor("azure.json")
 
-	configLock.Lock()
-	err = ioutil.WriteFile(configPath, cj, 0644)
-	configLock.Unlock()
-
-	if err != nil {
-		return nil, err
-	}
-
-	return c, nil
+		return nil
+	})
 }
 func (az *Azure) GetConfig() (*CustomPricing, error) {
-	c, err := GetCustomPricingData("azure.json")
+	c, err := az.Config.GetCustomPricingData()
 	if c.Discount == "" {
 		c.Discount = "0%"
 	}

+ 32 - 46
cloud/customprovider.go

@@ -3,7 +3,6 @@ package cloud
 import (
 	"encoding/json"
 	"io"
-	"io/ioutil"
 	"net/url"
 	"strconv"
 	"strings"
@@ -27,6 +26,7 @@ type CustomProvider struct {
 	GPULabel                string
 	GPULabelValue           string
 	DownloadPricingDataLock sync.RWMutex
+	Config                  *ProviderConfig
 }
 
 type customProviderKey struct {
@@ -41,8 +41,8 @@ func (*CustomProvider) GetLocalStorageQuery(offset string) (string, error) {
 	return "", nil
 }
 
-func (*CustomProvider) GetConfig() (*CustomPricing, error) {
-	return GetCustomPricingData("default.json")
+func (cp *CustomProvider) GetConfig() (*CustomPricing, error) {
+	return cp.Config.GetCustomPricingData()
 }
 
 func (*CustomProvider) GetManagementPlatform() (string, error) {
@@ -54,60 +54,46 @@ func (*CustomProvider) ApplyReservedInstancePricing(nodes map[string]*Node) {
 }
 
 func (cp *CustomProvider) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("default.json")
-	if err != nil {
-		return nil, err
-	}
-
-	return configmapUpdate(c, configPathFor("default.json"), a)
+	return cp.Config.UpdateFromMap(a)
 }
 
 func (cp *CustomProvider) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("default.json")
-	if err != nil {
-		return nil, err
-	}
-
+	// Parse config updates from reader
 	a := make(map[string]interface{})
-	err = json.NewDecoder(r).Decode(&a)
+	err := json.NewDecoder(r).Decode(&a)
 	if err != nil {
 		return nil, err
 	}
-	for k, v := range a {
-		kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
-		vstr, ok := v.(string)
-		if ok {
-			err := SetCustomPricingField(c, kUpper, vstr)
-			if err != nil {
-				return nil, err
-			}
-		} else {
-			sci := v.(map[string]interface{})
-			sc := make(map[string]string)
-			for k, val := range sci {
-				sc[k] = val.(string)
+
+	// Update Config
+	c, err := cp.Config.Update(func(c *CustomPricing) error {
+		for k, v := range a {
+			kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
+			vstr, ok := v.(string)
+			if ok {
+				err := SetCustomPricingField(c, kUpper, vstr)
+				if err != nil {
+					return err
+				}
+			} else {
+				sci := v.(map[string]interface{})
+				sc := make(map[string]string)
+				for k, val := range sci {
+					sc[k] = val.(string)
+				}
+				c.SharedCosts = sc //todo: support reflection/multiple map fields
 			}
-			c.SharedCosts = sc //todo: support reflection/multiple map fields
 		}
-	}
 
-	cj, err := json.Marshal(c)
-	if err != nil {
-		return nil, err
-	}
-
-	configPath := configPathFor("default.json")
-
-	configLock.Lock()
-	err = ioutil.WriteFile(configPath, cj, 0644)
-	configLock.Unlock()
+		return nil
+	})
 
 	if err != nil {
 		return nil, err
 	}
+
 	defer cp.DownloadPricingData()
 	return c, nil
-
 }
 
 func (cp *CustomProvider) ClusterInfo() (map[string]string, error) {
@@ -168,7 +154,7 @@ func (cp *CustomProvider) DownloadPricingData() error {
 		m := make(map[string]*NodePrice)
 		cp.Pricing = m
 	}
-	p, err := GetCustomPricingData("default.json")
+	p, err := cp.Config.GetCustomPricingData()
 	if err != nil {
 		return err
 	}
@@ -213,8 +199,8 @@ func (*CustomProvider) QuerySQL(query string) ([]byte, error) {
 	return nil, nil
 }
 
-func (*CustomProvider) PVPricing(pvk PVKey) (*PV, error) {
-	cpricing, err := GetCustomPricingData("default.json")
+func (cp *CustomProvider) PVPricing(pvk PVKey) (*PV, error) {
+	cpricing, err := cp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}
@@ -223,8 +209,8 @@ func (*CustomProvider) PVPricing(pvk PVKey) (*PV, error) {
 	}, nil
 }
 
-func (*CustomProvider) NetworkPricing() (*Network, error) {
-	cpricing, err := GetCustomPricingData("default.json")
+func (cp *CustomProvider) NetworkPricing() (*Network, error) {
+	cpricing, err := cp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}

+ 60 - 78
cloud/gcpprovider.go

@@ -51,6 +51,7 @@ type GCP struct {
 	BillingDataDataset      string
 	DownloadPricingDataLock sync.RWMutex
 	ReservedInstances       []*GCPReservedInstance
+	Config                  *ProviderConfig
 	*CustomProvider
 }
 
@@ -86,7 +87,7 @@ func (gcp *GCP) GetLocalStorageQuery(offset string) (string, error) {
 }
 
 func (gcp *GCP) GetConfig() (*CustomPricing, error) {
-	c, err := GetCustomPricingData("gcp.json")
+	c, err := gcp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}
@@ -119,97 +120,78 @@ func (gcp *GCP) GetManagementPlatform() (string, error) {
 }
 
 func (gcp *GCP) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("gcp.json")
-	if err != nil {
-		return nil, err
-	}
-
-	return configmapUpdate(c, configPathFor("gcp.json"), a)
+	return gcp.Config.UpdateFromMap(a)
 }
 
 func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
-	c, err := GetCustomPricingData("gcp.json")
-	if err != nil {
-		return nil, err
-	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	if updateType == BigqueryUpdateType {
-		a := BigQueryConfig{}
-		err = json.NewDecoder(r).Decode(&a)
-		if err != nil {
-			return nil, err
-		}
+	return gcp.Config.Update(func(c *CustomPricing) error {
+		if updateType == BigqueryUpdateType {
+			a := BigQueryConfig{}
+			err := json.NewDecoder(r).Decode(&a)
+			if err != nil {
+				return err
+			}
 
-		c.ProjectID = a.ProjectID
-		c.BillingDataDataset = a.BillingDataDataset
+			c.ProjectID = a.ProjectID
+			c.BillingDataDataset = a.BillingDataDataset
 
-		j, err := json.Marshal(a.Key)
-		if err != nil {
-			return nil, err
-		}
+			j, err := json.Marshal(a.Key)
+			if err != nil {
+				return err
+			}
 
-		keyPath := path + "key.json"
-		err = ioutil.WriteFile(keyPath, j, 0644)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		a := make(map[string]interface{})
-		err = json.NewDecoder(r).Decode(&a)
-		if err != nil {
-			return nil, err
-		}
-		for k, v := range a {
-			kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
-			vstr, ok := v.(string)
-			if ok {
-				err := SetCustomPricingField(c, kUpper, vstr)
-				if err != nil {
-					return nil, err
-				}
-			} else {
-				sci := v.(map[string]interface{})
-				sc := make(map[string]string)
-				for k, val := range sci {
-					sc[k] = val.(string)
+			path := os.Getenv("CONFIG_PATH")
+			if path == "" {
+				path = "/models/"
+			}
+
+			keyPath := path + "key.json"
+			err = ioutil.WriteFile(keyPath, j, 0644)
+			if err != nil {
+				return err
+			}
+		} else {
+			a := make(map[string]interface{})
+			err := json.NewDecoder(r).Decode(&a)
+			if err != nil {
+				return err
+			}
+			for k, v := range a {
+				kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
+				vstr, ok := v.(string)
+				if ok {
+					err := SetCustomPricingField(c, kUpper, vstr)
+					if err != nil {
+						return err
+					}
+				} else {
+					sci := v.(map[string]interface{})
+					sc := make(map[string]string)
+					for k, val := range sci {
+						sc[k] = val.(string)
+					}
+					c.SharedCosts = sc //todo: support reflection/multiple map fields
 				}
-				c.SharedCosts = sc //todo: support reflection/multiple map fields
 			}
 		}
-	}
 
-	cj, err := json.Marshal(c)
-	if err != nil {
-		return nil, err
-	}
-	remoteEnabled := os.Getenv(remoteEnabled)
-	if remoteEnabled == "true" {
-		err = UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
-		if err != nil {
-			return nil, err
+		remoteEnabled := os.Getenv(remoteEnabled)
+		if remoteEnabled == "true" {
+			err := UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
+			if err != nil {
+				return err
+			}
 		}
-	}
 
-	configPath := path + "gcp.json"
-
-	configLock.Lock()
-	err = ioutil.WriteFile(configPath, cj, 0644)
-	configLock.Unlock()
-	if err != nil {
-		return nil, err
-	}
-
-	return c, nil
+		return nil
+	})
 }
 
 // ExternalAllocations represents tagged assets outside the scope of kubernetes.
 // "start" and "end" are dates of the format YYYY-MM-DD
 // "aggregator" is the tag used to determine how to allocate those assets, ie namespace, pod, etc.
 func (gcp *GCP) ExternalAllocations(start string, end string, aggregator string, filterType string, filterValue string) ([]*OutOfClusterAllocation, error) {
-	c, err := GetCustomPricingData("gcp.json")
+	c, err := gcp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}
@@ -234,7 +216,7 @@ func (gcp *GCP) ExternalAllocations(start string, end string, aggregator string,
 
 // QuerySQL should query BigQuery for billing data for out of cluster costs.
 func (gcp *GCP) QuerySQL(query string) ([]*OutOfClusterAllocation, error) {
-	c, err := GetCustomPricingData("gcp.json")
+	c, err := gcp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}
@@ -687,7 +669,7 @@ func (gcp *GCP) parsePages(inputKeys map[string]Key, pvKeys map[string]PVKey) (m
 func (gcp *GCP) DownloadPricingData() error {
 	gcp.DownloadPricingDataLock.Lock()
 	defer gcp.DownloadPricingDataLock.Unlock()
-	c, err := GetCustomPricingData("gcp.json")
+	c, err := gcp.Config.GetCustomPricingData()
 	if err != nil {
 		klog.V(2).Infof("Error downloading default pricing data: %s", err.Error())
 		return err
@@ -760,8 +742,8 @@ func (gcp *GCP) PVPricing(pvk PVKey) (*PV, error) {
 }
 
 // Stubbed NetworkPricing for GCP. Pull directly from gcp.json for now
-func (c *GCP) NetworkPricing() (*Network, error) {
-	cpricing, err := GetCustomPricingData("gcp.json")
+func (gcp *GCP) NetworkPricing() (*Network, error) {
+	cpricing, err := gcp.Config.GetCustomPricingData()
 	if err != nil {
 		return nil, err
 	}

+ 4 - 150
cloud/provider.go

@@ -2,16 +2,12 @@ package cloud
 
 import (
 	"database/sql"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net/url"
 	"os"
-	"reflect"
 	"strings"
-	"sync"
 
 	"k8s.io/klog"
 
@@ -34,9 +30,6 @@ var createTableStatements = []string{
 	);`,
 }
 
-// This Mutex is used to control read/writes to our default config file
-var configLock sync.Mutex
-
 // ReservedInstanceData keeps record of resources on a node should be
 // priced at reserved rates
 type ReservedInstanceData struct {
@@ -210,120 +203,6 @@ func CustomPricesEnabled(p Provider) bool {
 	return config.CustomPricesEnabled == "true"
 }
 
-// DefaultPricing should be returned so we can do computation even if no file is supplied.
-func DefaultPricing() *CustomPricing {
-	return &CustomPricing{
-		Provider:              "base",
-		Description:           "Default prices based on GCP us-central1",
-		CPU:                   "0.031611",
-		SpotCPU:               "0.006655",
-		RAM:                   "0.004237",
-		SpotRAM:               "0.000892",
-		GPU:                   "0.95",
-		Storage:               "0.00005479452",
-		ZoneNetworkEgress:     "0.01",
-		RegionNetworkEgress:   "0.01",
-		InternetNetworkEgress: "0.12",
-		CustomPricesEnabled:   "false",
-	}
-}
-
-// GetDefaultPricingData will search for a json file representing pricing data in /models/ and use it for base pricing info.
-func GetCustomPricingData(fname string) (*CustomPricing, error) {
-	configLock.Lock()
-	defer configLock.Unlock()
-
-	path := configPathFor(fname)
-
-	exists, err := fileExists(path)
-	// File Error other than NotExists
-	if err != nil {
-		klog.Infof("Custom Pricing file at path '%s' read error: '%s'", path, err.Error())
-		return DefaultPricing(), err
-	}
-
-	// File Doesn't Exist
-	if !exists {
-		klog.Infof("Could not find Custom Pricing file at path '%s'", path)
-		c := DefaultPricing()
-		cj, err := json.Marshal(c)
-		if err != nil {
-			return c, err
-		}
-
-		err = ioutil.WriteFile(path, cj, 0644)
-		if err != nil {
-			klog.Infof("Could not write Custom Pricing file to path '%s'", path)
-			return c, err
-		}
-
-		return c, nil
-	}
-
-	// File Exists - Read all contents of file, unmarshal json
-	byteValue, err := ioutil.ReadFile(path)
-	if err != nil {
-		klog.Infof("Could not read Custom Pricing file at path %s", path)
-		return DefaultPricing(), err
-	}
-
-	var customPricing CustomPricing
-	err = json.Unmarshal(byteValue, &customPricing)
-	if err != nil {
-		klog.Infof("Could not decode Custom Pricing file at path %s", path)
-		return DefaultPricing(), err
-	}
-
-	return &customPricing, nil
-}
-
-func configmapUpdate(c *CustomPricing, path string, a map[string]string) (*CustomPricing, error) {
-	for k, v := range a {
-		kUpper := strings.Title(k) // Just so we consistently supply / receive the same values, uppercase the first letter.
-		err := SetCustomPricingField(c, kUpper, v)
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	cj, err := json.Marshal(c)
-	if err != nil {
-		return c, err
-	}
-
-	configLock.Lock()
-	err = ioutil.WriteFile(path, cj, 0644)
-	configLock.Unlock()
-
-	if err != nil {
-		return c, err
-	}
-
-	return c, nil
-}
-
-func SetCustomPricingField(obj *CustomPricing, name string, value string) error {
-	structValue := reflect.ValueOf(obj).Elem()
-	structFieldValue := structValue.FieldByName(name)
-
-	if !structFieldValue.IsValid() {
-		return fmt.Errorf("No such field: %s in obj", name)
-	}
-
-	if !structFieldValue.CanSet() {
-		return fmt.Errorf("Cannot set %s field value", name)
-	}
-
-	structFieldType := structFieldValue.Type()
-	val := reflect.ValueOf(value)
-	if structFieldType != val.Type() {
-		return fmt.Errorf("Provided value type didn't match custom pricing field type")
-	}
-
-	structFieldValue.Set(val)
-	return nil
-}
-
 // NewProvider looks at the nodespec or provider metadata server to decide which provider to instantiate.
 func NewProvider(cache clustercache.ClusterCache, apiKey string) (Provider, error) {
 	if metadata.OnGCE() {
@@ -334,6 +213,7 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string) (Provider, erro
 		return &GCP{
 			Clientset: cache,
 			APIKey:    apiKey,
+			Config:    NewProviderConfig("gcp.json"),
 		}, nil
 	}
 
@@ -347,16 +227,19 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string) (Provider, erro
 		klog.V(2).Info("Found ProviderID starting with \"aws\", using AWS Provider")
 		return &AWS{
 			Clientset: cache,
+			Config:    NewProviderConfig("aws.json"),
 		}, nil
 	} else if strings.HasPrefix(provider, "azure") {
 		klog.V(2).Info("Found ProviderID starting with \"azure\", using Azure Provider")
 		return &Azure{
 			Clientset: cache,
+			Config:    NewProviderConfig("azure.json"),
 		}, nil
 	} else {
 		klog.V(2).Info("Unsupported provider, falling back to default")
 		return &CustomProvider{
 			Clientset: cache,
+			Config:    NewProviderConfig("default.json"),
 		}, nil
 	}
 }
@@ -446,32 +329,3 @@ func GetOrCreateClusterMeta(cluster_id, cluster_name string) (string, string, er
 
 	return id, name, nil
 }
-
-// File exists has three different return cases that should be handled:
-//   1. File exists and is not a directory (true, nil)
-//   2. File does not exist (false, nil)
-//   3. File may or may not exist. Error occurred during stat (false, error)
-// The third case represents the scenario where the stat returns an error,
-// but the error isn't relevant to the path. This can happen when the current
-// user doesn't have permission to access the file.
-func fileExists(filename string) (bool, error) {
-	info, err := os.Stat(filename)
-	if err != nil {
-		if os.IsNotExist(err) {
-			return false, nil
-		}
-
-		return false, err
-	}
-
-	return !info.IsDir(), nil
-}
-
-// Returns the configuration directory concatenated with a specific config file name
-func configPathFor(filename string) string {
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	return path + filename
-}

+ 216 - 0
cloud/providerconfig.go

@@ -0,0 +1,216 @@
+package cloud
+
+import (
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"os"
+	"reflect"
+	"strings"
+	"sync"
+
+	"k8s.io/klog"
+)
+
+// ProviderConfig is a utility class that provides a thread-safe configuration
+// storage/cache for all Provider implementations
+type ProviderConfig struct {
+	lock          *sync.Mutex
+	fileName      string
+	configPath    string
+	customPricing *CustomPricing
+}
+
+// Creates a new ProviderConfig instance
+func NewProviderConfig(file string) *ProviderConfig {
+	return &ProviderConfig{
+		lock:          new(sync.Mutex),
+		fileName:      file,
+		configPath:    configPathFor(file),
+		customPricing: nil,
+	}
+}
+
+// Non-ThreadSafe logic to load the config file if a cache does not exist. Flag to write
+// the default config if the config file doesn't exist.
+func (pc *ProviderConfig) loadConfig(writeIfNotExists bool) (*CustomPricing, error) {
+	if pc.customPricing != nil {
+		return pc.customPricing, nil
+	}
+
+	exists, err := fileExists(pc.configPath)
+	// File Error other than NotExists
+	if err != nil {
+		klog.Infof("Custom Pricing file at path '%s' read error: '%s'", pc.configPath, err.Error())
+		return DefaultPricing(), err
+	}
+
+	// File Doesn't Exist
+	if !exists {
+		klog.Infof("Could not find Custom Pricing file at path '%s'", pc.configPath)
+		pc.customPricing = DefaultPricing()
+
+		// Only write the file if flag enabled
+		if writeIfNotExists {
+			cj, err := json.Marshal(pc.customPricing)
+			if err != nil {
+				return pc.customPricing, err
+			}
+
+			err = ioutil.WriteFile(pc.configPath, cj, 0644)
+			if err != nil {
+				klog.Infof("Could not write Custom Pricing file to path '%s'", pc.configPath)
+				return pc.customPricing, err
+			}
+		}
+
+		return pc.customPricing, nil
+	}
+
+	// File Exists - Read all contents of file, unmarshal json
+	byteValue, err := ioutil.ReadFile(pc.configPath)
+	if err != nil {
+		klog.Infof("Could not read Custom Pricing file at path %s", pc.configPath)
+		// If read fails, we don't want to cache default, assuming that the file is valid
+		return DefaultPricing(), err
+	}
+
+	var customPricing CustomPricing
+	err = json.Unmarshal(byteValue, &customPricing)
+	if err != nil {
+		klog.Infof("Could not decode Custom Pricing file at path %s", pc.configPath)
+		return DefaultPricing(), err
+	}
+
+	pc.customPricing = &customPricing
+
+	return pc.customPricing, nil
+}
+
+// ThreadSafe method for retrieving the custom pricing config.
+func (pc *ProviderConfig) GetCustomPricingData() (*CustomPricing, error) {
+	pc.lock.Lock()
+	defer pc.lock.Unlock()
+
+	return pc.loadConfig(true)
+}
+
+// Allows a call to manually update the configuration while maintaining proper thread-safety
+// for read/write methods.
+func (pc *ProviderConfig) Update(updateFunc func(*CustomPricing) error) (*CustomPricing, error) {
+	pc.lock.Lock()
+	defer pc.lock.Unlock()
+
+	// Load Config, set flag to _not_ write if failure to find file.
+	// We're about to write the updated values, so we don't want to double write.
+	c, _ := pc.loadConfig(false)
+
+	// Execute Update - On error, return the in-memory config but don't update cache
+	// explicitly
+	err := updateFunc(c)
+	if err != nil {
+		return c, err
+	}
+
+	// Cache Update (possible the ptr already references the cached value)
+	pc.customPricing = c
+
+	cj, err := json.Marshal(c)
+	if err != nil {
+		return c, err
+	}
+
+	err = ioutil.WriteFile(pc.configPath, cj, 0644)
+
+	if err != nil {
+		return c, err
+	}
+
+	return c, nil
+}
+
+// ThreadSafe update of the config using a string map
+func (pc *ProviderConfig) UpdateFromMap(a map[string]string) (*CustomPricing, error) {
+	// Run our Update() method using SetCustomPricingField logic
+	return pc.Update(func(c *CustomPricing) error {
+		for k, v := range a {
+			// Just so we consistently supply / receive the same values, uppercase the first letter.
+			kUpper := strings.Title(k)
+			err := SetCustomPricingField(c, kUpper, v)
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	})
+}
+
+// DefaultPricing should be returned so we can do computation even if no file is supplied.
+func DefaultPricing() *CustomPricing {
+	return &CustomPricing{
+		Provider:              "base",
+		Description:           "Default prices based on GCP us-central1",
+		CPU:                   "0.031611",
+		SpotCPU:               "0.006655",
+		RAM:                   "0.004237",
+		SpotRAM:               "0.000892",
+		GPU:                   "0.95",
+		Storage:               "0.00005479452",
+		ZoneNetworkEgress:     "0.01",
+		RegionNetworkEgress:   "0.01",
+		InternetNetworkEgress: "0.12",
+		CustomPricesEnabled:   "false",
+	}
+}
+
+func SetCustomPricingField(obj *CustomPricing, name string, value string) error {
+	structValue := reflect.ValueOf(obj).Elem()
+	structFieldValue := structValue.FieldByName(name)
+
+	if !structFieldValue.IsValid() {
+		return fmt.Errorf("No such field: %s in obj", name)
+	}
+
+	if !structFieldValue.CanSet() {
+		return fmt.Errorf("Cannot set %s field value", name)
+	}
+
+	structFieldType := structFieldValue.Type()
+	val := reflect.ValueOf(value)
+	if structFieldType != val.Type() {
+		return fmt.Errorf("Provided value type didn't match custom pricing field type")
+	}
+
+	structFieldValue.Set(val)
+	return nil
+}
+
+// File exists has three different return cases that should be handled:
+//   1. File exists and is not a directory (true, nil)
+//   2. File does not exist (false, nil)
+//   3. File may or may not exist. Error occurred during stat (false, error)
+// The third case represents the scenario where the stat returns an error,
+// but the error isn't relevant to the path. This can happen when the current
+// user doesn't have permission to access the file.
+func fileExists(filename string) (bool, error) {
+	info, err := os.Stat(filename)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+
+		return false, err
+	}
+
+	return !info.IsDir(), nil
+}
+
+// Returns the configuration directory concatenated with a specific config file name
+func configPathFor(filename string) string {
+	path := os.Getenv("CONFIG_PATH")
+	if path == "" {
+		path = "/models/"
+	}
+	return path + filename
+}