Sfoglia il codice sorgente

Move some of the noise in logic repetition to a provider function. Ensure consistency accross GetCustomPricing() method. Additionally ensure we are locking the same accross all required reads/writes of the config file.

Matt Bolt 6 anni fa
parent
commit
38b1f16f9c
5 ha cambiato i file con 91 aggiunte e 72 eliminazioni
  1. 9 12
      cloud/awsprovider.go
  2. 9 11
      cloud/azureprovider.go
  3. 6 12
      cloud/customprovider.go
  4. 4 7
      cloud/gcpprovider.go
  5. 63 30
      cloud/provider.go

+ 9 - 12
cloud/awsprovider.go

@@ -273,13 +273,10 @@ func (aws *AWS) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing,
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	configPath := path + "aws.json"
-	return configmapUpdate(c, configPath, a)
+
+	return configmapUpdate(c, configPathFor("aws.json"), a)
 }
+
 func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
 	c, err := GetCustomPricingData("aws.json")
 	if err != nil {
@@ -345,11 +342,9 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	path += "aws.json"
+
+	path := configPathFor("aws.json")
+
 	remoteEnabled := os.Getenv(remoteEnabled)
 	if remoteEnabled == "true" {
 		err = UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
@@ -357,9 +352,11 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 			return nil, err
 		}
 	}
+
 	configLock.Lock()
 	err = ioutil.WriteFile(path, cj, 0644)
-	defer configLock.Unlock()
+	configLock.Unlock()
+
 	if err != nil {
 		return nil, err
 	}

+ 9 - 11
cloud/azureprovider.go

@@ -518,24 +518,18 @@ func (az *Azure) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing,
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	configPath := path + "azure.json"
-	return configmapUpdate(c, configPath, a)
+
+	return configmapUpdate(c, configPathFor("azure.json"), a)
 }
 
 func (az *Azure) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
 	defer az.DownloadPricingData()
+
 	c, err := GetCustomPricingData("azure.json")
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
+
 	a := make(map[string]interface{})
 	err = json.NewDecoder(r).Decode(&a)
 	if err != nil {
@@ -558,10 +552,12 @@ func (az *Azure) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, e
 			c.SharedCosts = sc //todo: support reflection/multiple map fields
 		}
 	}
+
 	cj, err := json.Marshal(c)
 	if err != nil {
 		return nil, err
 	}
+
 	remoteEnabled := os.Getenv(remoteEnabled)
 	if remoteEnabled == "true" {
 		err = UpdateClusterMeta(os.Getenv(clusterIDKey), c.ClusterName)
@@ -570,10 +566,12 @@ func (az *Azure) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, e
 		}
 	}
 
-	configPath := path + "azure.json"
+	configPath := configPathFor("azure.json")
+
 	configLock.Lock()
 	err = ioutil.WriteFile(configPath, cj, 0644)
 	configLock.Unlock()
+
 	if err != nil {
 		return nil, err
 	}

+ 6 - 12
cloud/customprovider.go

@@ -5,7 +5,6 @@ import (
 	"io"
 	"io/ioutil"
 	"net/url"
-	"os"
 	"strconv"
 	"strings"
 	"sync"
@@ -59,12 +58,8 @@ func (cp *CustomProvider) UpdateConfigFromConfigMap(a map[string]string) (*Custo
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	configPath := path + "default.json"
-	return configmapUpdate(c, configPath, a)
+
+	return configmapUpdate(c, configPathFor("default.json"), a)
 }
 
 func (cp *CustomProvider) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
@@ -72,10 +67,7 @@ func (cp *CustomProvider) UpdateConfig(r io.Reader, updateType string) (*CustomP
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
+
 	a := make(map[string]interface{})
 	err = json.NewDecoder(r).Decode(&a)
 	if err != nil {
@@ -104,10 +96,12 @@ func (cp *CustomProvider) UpdateConfig(r io.Reader, updateType string) (*CustomP
 		return nil, err
 	}
 
-	configPath := path + "default.json"
+	configPath := configPathFor("default.json")
+
 	configLock.Lock()
 	err = ioutil.WriteFile(configPath, cj, 0644)
 	configLock.Unlock()
+
 	if err != nil {
 		return nil, err
 	}

+ 4 - 7
cloud/gcpprovider.go

@@ -123,12 +123,8 @@ func (gcp *GCP) UpdateConfigFromConfigMap(a map[string]string) (*CustomPricing,
 	if err != nil {
 		return nil, err
 	}
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
-	}
-	configPath := path + "gcp.json"
-	return configmapUpdate(c, configPath, a)
+
+	return configmapUpdate(c, configPathFor("gcp.json"), a)
 }
 
 func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, error) {
@@ -184,6 +180,7 @@ func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 			}
 		}
 	}
+
 	cj, err := json.Marshal(c)
 	if err != nil {
 		return nil, err
@@ -197,6 +194,7 @@ func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 	}
 
 	configPath := path + "gcp.json"
+
 	configLock.Lock()
 	err = ioutil.WriteFile(configPath, cj, 0644)
 	configLock.Unlock()
@@ -205,7 +203,6 @@ func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 	}
 
 	return c, nil
-
 }
 
 // ExternalAllocations represents tagged assets outside the scope of kubernetes.

+ 63 - 30
cloud/provider.go

@@ -233,29 +233,17 @@ func GetCustomPricingData(fname string) (*CustomPricing, error) {
 	configLock.Lock()
 	defer configLock.Unlock()
 
-	path := os.Getenv("CONFIG_PATH")
-	if path == "" {
-		path = "/models/"
+	path := configPathFor(fname)
+
+	exists, err := fileExists(path)
+	// File Error other than NotExists
+	if err != nil {
+		klog.Infof("Custom Pricing file at path '%s' read error: '%s'", path, err.Error())
+		return DefaultPricing(), err
 	}
-	path += fname
-	if _, err := os.Stat(path); err == nil {
-		jsonFile, err := os.Open(path)
-		if err != nil {
-			return nil, err
-		}
-		defer jsonFile.Close()
-		byteValue, err := ioutil.ReadAll(jsonFile)
-		if err != nil {
-			return nil, err
-		}
-		var customPricing = &CustomPricing{}
-		err = json.Unmarshal([]byte(byteValue), customPricing)
-		if err != nil {
-			klog.Infof("Could not decode Custom Pricing file at path %s", path)
-			return DefaultPricing(), err
-		}
-		return customPricing, nil
-	} else if os.IsNotExist(err) {
+
+	// File Doesn't Exist
+	if !exists {
 		klog.Infof("Could not find Custom Pricing file at path '%s'", path)
 		c := DefaultPricing()
 		cj, err := json.Marshal(c)
@@ -266,13 +254,27 @@ func GetCustomPricingData(fname string) (*CustomPricing, error) {
 		err = ioutil.WriteFile(path, cj, 0644)
 		if err != nil {
 			klog.Infof("Could not write Custom Pricing file to path '%s'", path)
-			return nil, err
+			return c, err
 		}
+
 		return c, nil
-	} else {
-		klog.Infof("Custom Pricing file at path '%s' read error: '%s'", path, err.Error())
+	}
+
+	// File Exists - Read all contents of file, unmarshal json
+	byteValue, err := ioutil.ReadFile(path)
+	if err != nil {
+		klog.Infof("Could not read Custom Pricing file at path %s", path)
 		return DefaultPricing(), err
 	}
+
+	var customPricing CustomPricing
+	err = json.Unmarshal(byteValue, &customPricing)
+	if err != nil {
+		klog.Infof("Could not decode Custom Pricing file at path %s", path)
+		return DefaultPricing(), err
+	}
+
+	return &customPricing, nil
 }
 
 func configmapUpdate(c *CustomPricing, path string, a map[string]string) (*CustomPricing, error) {
@@ -284,17 +286,19 @@ func configmapUpdate(c *CustomPricing, path string, a map[string]string) (*Custo
 		}
 	}
 
-	configLock.Lock()
-	defer configLock.Unlock()
-
 	cj, err := json.Marshal(c)
 	if err != nil {
-		return nil, err
+		return c, err
 	}
+
+	configLock.Lock()
 	err = ioutil.WriteFile(path, cj, 0644)
+	configLock.Unlock()
+
 	if err != nil {
-		return nil, err
+		return c, err
 	}
+
 	return c, nil
 }
 
@@ -442,3 +446,32 @@ func GetOrCreateClusterMeta(cluster_id, cluster_name string) (string, string, er
 
 	return id, name, nil
 }
+
+// File exists has three different return cases that should be handled:
+//   1. File exists and is not a directory (true, nil)
+//   2. File does not exist (false, nil)
+//   3. File may or may not exist. Error occurred during stat (false, error)
+// The third case represents the scenario where the stat returns an error,
+// but the error isn't relevant to the path. This can happen when the current
+// user doesn't have permission to access the file.
+func fileExists(filename string) (bool, error) {
+	info, err := os.Stat(filename)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+
+		return false, err
+	}
+
+	return !info.IsDir(), nil
+}
+
+// Returns the configuration directory concatenated with a specific config file name
+func configPathFor(filename string) string {
+	path := os.Getenv("CONFIG_PATH")
+	if path == "" {
+		path = "/models/"
+	}
+	return path + filename
+}