Quellcode durchsuchen

fix csv casing, precedence, add test

Ajay Tripathy vor 2 Jahren
Ursprung
Commit
740fab0435
3 geänderte Dateien mit 93 neuen und 30 gelöschten Zeilen
  1. 5 0
      configs/pricing_schema_mixed_gpu_ondemand.csv
  2. 35 30
      pkg/cloud/provider/csvprovider.go
  3. 53 0
      test/cloud_test.go

+ 5 - 0
configs/pricing_schema_mixed_gpu_ondemand.csv

@@ -0,0 +1,5 @@
+EndTimestamp,InstanceID,Region,AssetClass,InstanceIDField,InstanceType,MarketPriceHourly,Version
+2028-01-06 23:34:45 UTC,Reserved,,node,metadata.labels.TestClusterUsage,,0.654795,
+2028-01-06 23:34:45 UTC,OnDemand,,node,metadata.labels.TestClusterUsage,,0.90411,
+2028-01-06 23:34:45 UTC,a100-ondemand,,gpu,nvidia.com/gpu_type,,0.5,
+2028-01-06 23:34:45 UTC,a100-reserved,,gpu,nvidia.com/gpu_type,,1,

+ 35 - 30
pkg/cloud/provider/csvprovider.go

@@ -228,60 +228,65 @@ func (k *csvKey) ID() string {
 	return k.ProviderID
 }
 
-func (c *CSVProvider) NodePricing(key models.Key) (*models.Node, models.PricingMetadata, error) {
-	c.DownloadPricingDataLock.RLock()
-	defer c.DownloadPricingDataLock.RUnlock()
-	meta := models.PricingMetadata{}
-	var node *models.Node
+func (c *CSVProvider) nodePricing(key models.Key) *models.Node {
 	if p, ok := c.Pricing[key.ID()]; ok {
-		node = &models.Node{
+		return &models.Node{
 			Cost:        p.MarketPriceHourly,
 			PricingType: models.CsvExact,
 		}
 	}
+
 	s := strings.Split(key.ID(), ",") // Try without a region to be sure
 	if len(s) == 2 {
 		if p, ok := c.Pricing[s[1]]; ok {
-			node = &models.Node{
+			return &models.Node{
 				Cost:        p.MarketPriceHourly,
 				PricingType: models.CsvExact,
 			}
 		}
 	}
+
 	classKey := key.Features() // Use node attributes to try and do a class match
 	if cost, ok := c.NodeClassPricing[classKey]; ok {
 		log.Infof("Unable to find provider ID `%s`, using features:`%s`", key.ID(), key.Features())
-		node = &models.Node{
+		return &models.Node{
 			Cost:        fmt.Sprintf("%f", cost),
 			PricingType: models.CsvClass,
 		}
 	}
 
-	if node != nil {
-		if t := key.GPUType(); t != "" {
-			t = strings.ToLower(t)
-			count := key.GPUCount()
-			node.GPU = strconv.Itoa(count)
-			hourly := 0.0
-			if p, ok := c.GPUClassPricing[t]; ok {
-				var err error
-				hourly, err = strconv.ParseFloat(p.MarketPriceHourly, 64)
-				if err != nil {
-					log.Errorf("Unable to parse %s as float", p.MarketPriceHourly)
-				}
-			}
-			totalCost := hourly * float64(count)
-			node.GPUCost = fmt.Sprintf("%f", totalCost)
-			nc, err := strconv.ParseFloat(node.Cost, 64)
+	return nil
+}
+
+func (c *CSVProvider) NodePricing(key models.Key) (*models.Node, models.PricingMetadata, error) {
+	c.DownloadPricingDataLock.RLock()
+	defer c.DownloadPricingDataLock.RUnlock()
+
+	node := c.nodePricing(key)
+	if node == nil {
+		return nil, models.PricingMetadata{}, fmt.Errorf("Unable to find Node matching `%s`:`%s`", key.ID(), key.Features())
+	}
+	if t := key.GPUType(); t != "" {
+		t = strings.ToLower(t)
+		count := key.GPUCount()
+		node.GPU = strconv.Itoa(count)
+		hourly := 0.0
+		if p, ok := c.GPUClassPricing[t]; ok {
+			var err error
+			hourly, err = strconv.ParseFloat(p.MarketPriceHourly, 64)
 			if err != nil {
-				log.Errorf("Unable to parse %s as float", node.Cost)
+				log.Errorf("Unable to parse %s as float", p.MarketPriceHourly)
 			}
-			node.Cost = fmt.Sprintf("%f", nc+totalCost)
 		}
-		return node, meta, nil
-	} else {
-		return nil, meta, fmt.Errorf("Unable to find Node matching `%s`:`%s`", key.ID(), key.Features())
+		totalCost := hourly * float64(count)
+		node.GPUCost = fmt.Sprintf("%f", totalCost)
+		nc, err := strconv.ParseFloat(node.Cost, 64)
+		if err != nil {
+			log.Errorf("Unable to parse %s as float", node.Cost)
+		}
+		node.Cost = fmt.Sprintf("%f", nc+totalCost)
 	}
+	return node, models.PricingMetadata{}, nil
 }
 
 func NodeValueFromMapField(m string, n *v1.Node, useRegion bool) string {
@@ -368,7 +373,7 @@ func (c *CSVProvider) GetKey(l map[string]string, n *v1.Node) models.Key {
 		gpuCount = gpuc.Value()
 	}
 	return &csvKey{
-		ProviderID: id,
+		ProviderID: strings.ToLower(id),
 		Labels:     l,
 		GPULabel:   c.GPUMapFields,
 		GPU:        gpuCount,

+ 53 - 0
test/cloud_test.go

@@ -580,6 +580,59 @@ func TestNodePriceFromCSVWithCase(t *testing.T) {
 
 }
 
+func TestNodePriceFromCSVMixed(t *testing.T) {
+	labelFooWant := "OnDemand"
+
+	confMan := config.NewConfigFileManager(&config.ConfigFileManagerOpts{
+		LocalConfigPath: "./",
+	})
+
+	n := &v1.Node{}
+	n.Labels = make(map[string]string)
+	n.Labels["TestClusterUsage"] = labelFooWant
+	n.Labels["nvidia.com/gpu_type"] = "a100-ondemand"
+	n.Status.Capacity = v1.ResourceList{"nvidia.com/gpu": *resource.NewScaledQuantity(2, 0)}
+	wantPrice := "1.904110"
+
+	labelFooWant2 := "Reserved"
+	n2 := &v1.Node{}
+	n2.Labels = make(map[string]string)
+	n2.Labels["TestClusterUsage"] = labelFooWant2
+	n2.Labels["nvidia.com/gpu_type"] = "a100-reserved"
+	n2.Status.Capacity = v1.ResourceList{"nvidia.com/gpu": *resource.NewScaledQuantity(1, 0)}
+
+	wantPrice2 := "1.654795"
+
+	c := &provider.CSVProvider{
+		CSVLocation: "../configs/pricing_schema_mixed_gpu_ondemand.csv",
+		CustomProvider: &provider.CustomProvider{
+			Config: provider.NewProviderConfig(confMan, "../configs/default.json"),
+		},
+	}
+	c.DownloadPricingData()
+	k := c.GetKey(n.Labels, n)
+	resN, _, err := c.NodePricing(k)
+	if err != nil {
+		t.Errorf("Error in NodePricing: %s", err.Error())
+	} else {
+		gotPrice := resN.Cost
+		if gotPrice != wantPrice {
+			t.Errorf("Wanted price '%s' got price '%s'", wantPrice, gotPrice)
+		}
+	}
+	k2 := c.GetKey(n2.Labels, n2)
+	resN2, _, err2 := c.NodePricing(k2)
+	if err2 != nil {
+		t.Errorf("Error in NodePricing: %s", err.Error())
+	} else {
+		gotPrice := resN2.Cost
+		if gotPrice != wantPrice2 {
+			t.Errorf("Wanted price '%s' got price '%s'", wantPrice2, gotPrice)
+		}
+	}
+
+}
+
 func TestNodePriceFromCSVByClass(t *testing.T) {
 	n := &v1.Node{}
 	n.Spec.ProviderID = "fakeproviderid"