Просмотр исходного кода

Merge pull request #2133 from opencost/niko-burndown-228

Protect config pricing against NaN and other invalid numerical values
Niko Kovacevic 2 лет назад
Родитель
Сommit
f99d2e61bd

+ 1 - 1
pkg/cloud/alibaba/provider.go

@@ -732,7 +732,7 @@ func (alibaba *Alibaba) UpdateConfig(r io.Reader, updateType string) (*models.Cu
 				if ok {
 					err := models.SetCustomPricingField(c, kUpper, vstr)
 					if err != nil {
-						return err
+						return fmt.Errorf("error setting custom pricing field: %w", err)
 					}
 				} else {
 					return fmt.Errorf("type error while updating config for %s", kUpper)

+ 1 - 1
pkg/cloud/aws/provider.go

@@ -591,7 +591,7 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 				if ok {
 					err := models.SetCustomPricingField(c, kUpper, vstr)
 					if err != nil {
-						return err
+						return fmt.Errorf("error setting custom pricing field: %w", err)
 					}
 				} else {
 					return fmt.Errorf("type error while updating config for %s", kUpper)

+ 33 - 4
pkg/cloud/models/models.go

@@ -3,6 +3,7 @@ package models
 import (
 	"fmt"
 	"io"
+	"math"
 	"reflect"
 	"strconv"
 	"strings"
@@ -215,24 +216,52 @@ func (cp *CustomPricing) GetSharedOverheadCostPerMonth() float64 {
 	return sharedCostPerMonth
 }
 
-func SetCustomPricingField(obj *CustomPricing, name string, value string) error {
+func sanitizeFloatString(number string, allowNaN bool) (string, error) {
+	num, err := strconv.ParseFloat(number, 64)
+	if err != nil {
+		return "", fmt.Errorf("expected a string representing a number; got '%s'", number)
+	}
+	if !allowNaN && math.IsNaN(num) {
+		return "", fmt.Errorf("expected a string representing a number; got 'NaN'")
+	}
 
+	// Format the numerical string we just parsed.
+	return strconv.FormatFloat(num, 'f', -1, 64), nil
+}
+
+func SetCustomPricingField(obj *CustomPricing, name string, value string) error {
 	structValue := reflect.ValueOf(obj).Elem()
 	structFieldValue := structValue.FieldByName(name)
 
 	if !structFieldValue.IsValid() {
-		return fmt.Errorf("No such field: %s in obj", name)
+		return fmt.Errorf("no such field: %s in obj", name)
 	}
 
 	if !structFieldValue.CanSet() {
-		return fmt.Errorf("Cannot set %s field value", name)
+		return fmt.Errorf("cannot set %s field value", name)
+	}
+
+	// If the custom pricing field is expected to be a string representation
+	// of a floating point number, e.g. a resource price, then do some extra
+	// validation work in order to prevent "NaN" and other invalid strings
+	// from getting set here.
+	switch strings.ToLower(name) {
+	case "cpu", "gpu", "ram", "spotcpu", "spotgpu", "spotram", "storage", "zonenetworkegress", "regionnetworkegress", "internetnetworkegress":
+		// Validate that "value" represents a real floating point number, and
+		// set precision, bits, etc. Do not allow NaN.
+		val, err := sanitizeFloatString(value, false)
+		if err != nil {
+			return fmt.Errorf("invalid numeric value for field '%s': %s", name, value)
+		}
+		value = val
+	default:
 	}
 
 	structFieldType := structFieldValue.Type()
 	value = sanitizePolicy.Sanitize(value)
 	val := reflect.ValueOf(value)
 	if structFieldType != val.Type() {
-		return fmt.Errorf("Provided value type didn't match custom pricing field type")
+		return fmt.Errorf("provided value type didn't match custom pricing field type")
 	}
 
 	structFieldValue.Set(val)

+ 118 - 0
pkg/cloud/models/models_test.go

@@ -0,0 +1,118 @@
+package models
+
+import (
+	"fmt"
+	"reflect"
+	"testing"
+)
+
+func TestSetSetCustomPricingField(t *testing.T) {
+	defaultValue := "1.0"
+
+	type testCase struct {
+		testName   string
+		fieldName  string
+		fieldValue string
+		expValue   string
+		expError   error
+	}
+
+	testCaseTemplates := []testCase{
+		{
+			testName:   "valid number for %s",
+			fieldName:  "%s",
+			fieldValue: "0.04321",
+			expValue:   "0.04321",
+			expError:   nil,
+		},
+		{
+			testName:   "long number for %s",
+			fieldName:  "%s",
+			fieldValue: "0.04321234321231234",
+			expValue:   "0.04321234321231234",
+			expError:   nil,
+		},
+		{
+			testName:   "illegal number for %s",
+			fieldName:  "%s",
+			fieldValue: "0.123.123",
+			expValue:   defaultValue, // assert that the default value is not mutated
+			expError:   fmt.Errorf("invalid numeric value for field"),
+		},
+		{
+			testName:   "NaN for %s",
+			fieldName:  "%s",
+			fieldValue: "NaN",
+			expValue:   defaultValue, // assert that the default value is not mutated
+			expError:   fmt.Errorf("invalid numeric value for field"),
+		},
+		{
+			testName:   "empty string for %s",
+			fieldName:  "%s",
+			fieldValue: "",
+			expValue:   defaultValue, // assert that the default value is not mutated
+			expError:   fmt.Errorf("invalid numeric value for field"),
+		},
+	}
+
+	numericFields := []string{
+		"CPU",
+		"GPU",
+		"RAM",
+		"SpotCPU",
+		"SpotGPU",
+		"SpotRAM",
+		"Storage",
+		"ZoneNetworkEgress",
+		"RegionNetworkEgress",
+		"InternetNetworkEgress",
+	}
+
+	testCases := []testCase{}
+
+	// Build one test case per-template, per-numeric field; this is obscure
+	// but it prevents me from having to write the same test for all 10
+	// numeric fields...
+	for _, field := range numericFields {
+		for _, tpl := range testCaseTemplates {
+			testCases = append(testCases, testCase{
+				testName:   fmt.Sprintf(tpl.testName, field),
+				fieldName:  fmt.Sprintf(tpl.fieldName, field),
+				fieldValue: tpl.fieldValue,
+				expValue:   tpl.expValue,
+				expError:   tpl.expError,
+			})
+		}
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.testName, func(t *testing.T) {
+			cp := &CustomPricing{
+				CPU:                   defaultValue,
+				SpotCPU:               defaultValue,
+				RAM:                   defaultValue,
+				SpotRAM:               defaultValue,
+				GPU:                   defaultValue,
+				SpotGPU:               defaultValue,
+				Storage:               defaultValue,
+				ZoneNetworkEgress:     defaultValue,
+				RegionNetworkEgress:   defaultValue,
+				InternetNetworkEgress: defaultValue,
+			}
+			err := SetCustomPricingField(cp, tc.fieldName, tc.fieldValue)
+			if err != nil && tc.expError == nil {
+				t.Errorf("unexpected error: %s", err)
+			}
+			if err == nil && tc.expError != nil {
+				t.Errorf("did not find expected error: %s", tc.expError)
+			}
+
+			structValue := reflect.ValueOf(cp).Elem()
+			structFieldValue := structValue.FieldByName(tc.fieldName)
+			actValue := structFieldValue.String()
+			if actValue != tc.expValue {
+				t.Errorf("expected field '%s' to be '%s'; actual value is '%s'", tc.fieldName, tc.expValue, actValue)
+			}
+		})
+	}
+}

+ 1 - 1
pkg/cloud/provider/customprovider.go

@@ -119,7 +119,7 @@ func (cp *CustomProvider) UpdateConfig(r io.Reader, updateType string) (*models.
 			if ok {
 				err := models.SetCustomPricingField(c, kUpper, vstr)
 				if err != nil {
-					return err
+					return fmt.Errorf("error setting custom pricing field: %w", err)
 				}
 			} else {
 				return fmt.Errorf("type error while updating config for %s", kUpper)

+ 6 - 6
pkg/cloud/provider/providerconfig.go

@@ -181,7 +181,7 @@ func (pc *ProviderConfig) Update(updateFunc func(*models.CustomPricing) error) (
 	// explicitly
 	err := updateFunc(c)
 	if err != nil {
-		return c, err
+		return c, fmt.Errorf("error updating provider config: %w", err)
 	}
 
 	// Cache Update (possible the ptr already references the cached value)
@@ -189,12 +189,12 @@ func (pc *ProviderConfig) Update(updateFunc func(*models.CustomPricing) error) (
 
 	cj, err := json.Marshal(c)
 	if err != nil {
-		return c, err
+		return c, fmt.Errorf("error marshaling JSON for provider config: %w", err)
 	}
-	err = pc.configFile.Write(cj)
 
+	err = pc.configFile.Write(cj)
 	if err != nil {
-		return c, err
+		return c, fmt.Errorf("error writing config file for provider config: %w", err)
 	}
 
 	return c, nil
@@ -210,14 +210,14 @@ func (pc *ProviderConfig) UpdateFromMap(a map[string]string) (*models.CustomPric
 			if kUpper == "CPU" || kUpper == "SpotCPU" || kUpper == "RAM" || kUpper == "SpotRAM" || kUpper == "GPU" || kUpper == "Storage" {
 				val, err := strconv.ParseFloat(v, 64)
 				if err != nil {
-					return fmt.Errorf("Unable to parse CPU from string to float: %s", err.Error())
+					return fmt.Errorf("unable to parse CPU from string to float: %s", err.Error())
 				}
 				v = fmt.Sprintf("%f", val/730)
 			}
 
 			err := models.SetCustomPricingField(c, kUpper, v)
 			if err != nil {
-				return err
+				return fmt.Errorf("error setting custom pricing field: %w", err)
 			}
 		}
 

+ 1 - 1
pkg/cloud/scaleway/provider.go

@@ -317,7 +317,7 @@ func (c *Scaleway) UpdateConfig(r io.Reader, updateType string) (*models.CustomP
 			if ok {
 				err := models.SetCustomPricingField(c, kUpper, vstr)
 				if err != nil {
-					return err
+					return fmt.Errorf("error setting custom pricing field: %w", err)
 				}
 			} else {
 				return fmt.Errorf("type error while updating config for %s", kUpper)

+ 15 - 0
pkg/costmodel/allocation_helpers.go

@@ -1771,6 +1771,21 @@ func (cm *CostModel) getNodePricing(nodeMap map[nodeKey]*nodePricing, nodeKey no
 		node.Source += "/customRAM"
 	}
 
+	// Double check each for NaNs, as there is a chance that our custom pricing
+	// config could, itself, contain NaNs...
+	if math.IsNaN(node.CostPerCPUHr) || math.IsInf(node.CostPerCPUHr, 0) {
+		log.Warnf("CostModel: %s: node pricing has illegal CPU value: %v (setting to 0.0)", nodeKey, node.CostPerCPUHr)
+		node.CostPerCPUHr = 0.0
+	}
+	if math.IsNaN(node.CostPerGPUHr) || math.IsInf(node.CostPerGPUHr, 0) {
+		log.Warnf("CostModel: %s: node pricing has illegal RAM value: %v (setting to 0.0)", nodeKey, node.CostPerGPUHr)
+		node.CostPerGPUHr = 0.0
+	}
+	if math.IsNaN(node.CostPerRAMGiBHr) || math.IsInf(node.CostPerRAMGiBHr, 0) {
+		log.Warnf("CostModel: %s: node pricing has illegal RAM value: %v (setting to 0.0)", nodeKey, node.CostPerRAMGiBHr)
+		node.CostPerRAMGiBHr = 0.0
+	}
+
 	return node
 }