Просмотр исходного кода

fix(azure): make ondemand pricing OS-aware for Windows nodes (#3820)

Signed-off-by: Varshith <kvarshithgowda@gmail.com>
Varshith 13 часов назад
Родитель
Сommit
1741baaf36
2 измененных файлов с 283 добавлено и 39 удалено
  1. 74 26
      pkg/cloud/azure/provider.go
  2. 209 13
      pkg/cloud/azure/provider_test.go

+ 74 - 26
pkg/cloud/azure/provider.go

@@ -274,32 +274,39 @@ func buildAzureRetailPricesURL(region string, skuName string, currencyCode strin
 	return pricingURL
 	return pricingURL
 }
 }
 
 
-func extractAzureVMRetailAndSpotPrices(resp *http.Response) (retailPrice string, spotPrice string, err error) {
+func extractAzureVMRetailAndSpotPrices(resp *http.Response) (linuxRetailPrice string, windowsRetailPrice string, spotPrice string, windowsSpotPrice string, err error) {
 	body, err := io.ReadAll(resp.Body)
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
-		return "", "", fmt.Errorf("Error getting response: %v", err)
+		return "", "", "", "", fmt.Errorf("error getting response: %w", err)
 	}
 	}
 
 
 	pricingPayload := AzureRetailPricing{}
 	pricingPayload := AzureRetailPricing{}
 	jsonErr := json.Unmarshal(body, &pricingPayload)
 	jsonErr := json.Unmarshal(body, &pricingPayload)
 	if jsonErr != nil {
 	if jsonErr != nil {
-		return "", "", fmt.Errorf("error unmarshalling data: %v", jsonErr)
+		return "", "", "", "", fmt.Errorf("error unmarshalling data: %w", jsonErr)
 	}
 	}
 	for _, item := range pricingPayload.Items {
 	for _, item := range pricingPayload.Items {
-		// note: Windows OS ondemand price will be equal to Linux, Adoption of Windows based
-		// computes are increasing in Azure we might want to enhance this in future.
-		if !strings.Contains(item.ProductName, "Windows") {
-			if strings.Contains(strings.ToLower(item.SkuName), " spot") {
+		skuLower := strings.ToLower(item.SkuName)
+		productLower := strings.ToLower(item.ProductName)
+		isWindowsProduct := strings.Contains(productLower, "windows")
+		if strings.Contains(skuLower, " spot") {
+			if isWindowsProduct {
+				windowsSpotPrice = fmt.Sprintf("%f", item.RetailPrice)
+			} else {
 				spotPrice = fmt.Sprintf("%f", item.RetailPrice)
 				spotPrice = fmt.Sprintf("%f", item.RetailPrice)
-			} else if !(strings.Contains(strings.ToLower(item.SkuName), "low priority") || strings.Contains(strings.ToLower(item.ProductName), "cloud services") || strings.Contains(strings.ToLower(item.ProductName), "cloudservices")) {
-				retailPrice = fmt.Sprintf("%f", item.RetailPrice)
+			}
+		} else if !(strings.Contains(skuLower, "low priority") || strings.Contains(productLower, "cloud services") || strings.Contains(productLower, "cloudservices")) {
+			if isWindowsProduct {
+				windowsRetailPrice = fmt.Sprintf("%f", item.RetailPrice)
+			} else {
+				linuxRetailPrice = fmt.Sprintf("%f", item.RetailPrice)
 			}
 			}
 		}
 		}
 	}
 	}
-	return retailPrice, spotPrice, nil
+	return linuxRetailPrice, windowsRetailPrice, spotPrice, windowsSpotPrice, nil
 }
 }
 
 
-func getRetailPrice(region string, skuName string, currencyCode string, spot bool) (string, error) {
+func getRetailPrice(region string, skuName string, currencyCode string, spot bool, isWindows bool) (string, error) {
 	pricingURL := buildAzureRetailPricesURL(region, skuName, currencyCode)
 	pricingURL := buildAzureRetailPricesURL(region, skuName, currencyCode)
 	log.Infof("starting download retail price payload from \"%s\"", pricingURL)
 	log.Infof("starting download retail price payload from \"%s\"", pricingURL)
 
 
@@ -308,29 +315,51 @@ func getRetailPrice(region string, skuName string, currencyCode string, spot boo
 	client := httputil.BoundedClient()
 	client := httputil.BoundedClient()
 	resp, err := client.Get(pricingURL)
 	resp, err := client.Get(pricingURL)
 	if err != nil {
 	if err != nil {
-		return "", fmt.Errorf("failed to fetch retail price with URL \"%s\": %v", pricingURL, err)
+		return "", fmt.Errorf("failed to fetch retail price with URL \"%s\": %w", pricingURL, err)
 	}
 	}
 
 
 	if resp.StatusCode < 200 && resp.StatusCode > 299 {
 	if resp.StatusCode < 200 && resp.StatusCode > 299 {
 		return "", fmt.Errorf("retail price responded with error status code %d", resp.StatusCode)
 		return "", fmt.Errorf("retail price responded with error status code %d", resp.StatusCode)
 	}
 	}
 
 
-	retailPrice, spotPrice, err := extractAzureVMRetailAndSpotPrices(resp)
+	linuxRetailPrice, windowsRetailPrice, spotPrice, windowsSpotPrice, err := extractAzureVMRetailAndSpotPrices(resp)
 	if err != nil {
 	if err != nil {
-		return "", fmt.Errorf("failed to extract azure prices: %v", err)
+		return "", fmt.Errorf("failed to extract azure prices: %w", err)
 	}
 	}
 
 
 	log.DedupedInfof(5, "done parsing retail price payload from \"%s\"\n", pricingURL)
 	log.DedupedInfof(5, "done parsing retail price payload from \"%s\"\n", pricingURL)
 
 
-	if spot && spotPrice != "" {
-		return spotPrice, nil
+	return selectRetailPrice(region, skuName, linuxRetailPrice, windowsRetailPrice, spotPrice, windowsSpotPrice, spot, isWindows)
+}
+
+// selectRetailPrice picks the price matching the node OS and pricing model.
+// Windows nodes prefer the Windows-specific price; when it is absent the Linux
+// price is used as a best-effort estimate and the fallback is logged so the
+// substitution is not silent.
+func selectRetailPrice(region, skuName, linuxRetailPrice, windowsRetailPrice, spotPrice, windowsSpotPrice string, spot, isWindows bool) (string, error) {
+	if spot {
+		if isWindows && windowsSpotPrice != "" {
+			return windowsSpotPrice, nil
+		}
+		if spotPrice != "" {
+			if isWindows {
+				log.Warnf("no Windows spot price for %q in %q region; falling back to Linux spot price", skuName, region)
+			}
+			return spotPrice, nil
+		}
 	}
 	}
 
 
-	if retailPrice == "" {
-		return retailPrice, fmt.Errorf("Couldn't find price for product \"%s\" in \"%s\" region", skuName, region)
+	selectedRetail := linuxRetailPrice
+	if isWindows && windowsRetailPrice != "" {
+		selectedRetail = windowsRetailPrice
+	} else if isWindows && linuxRetailPrice != "" {
+		log.Warnf("no Windows retail price for %q in %q region; falling back to Linux retail price", skuName, region)
+	}
+	if selectedRetail == "" {
+		return "", fmt.Errorf("couldn't find price for product %q in %q region", skuName, region)
 	}
 	}
 
 
-	return retailPrice, nil
+	return selectedRetail, nil
 }
 }
 
 
 func toRegionID(meterRegion string, regions map[string]string) (string, error) {
 func toRegionID(meterRegion string, regions map[string]string) (string, error) {
@@ -444,6 +473,17 @@ func (az *Azure) PricingSourceSummary() interface{} {
 	return az.Pricing
 	return az.Pricing
 }
 }
 
 
+// azureWindowsOS is the node OS label value that identifies a Windows node and
+// the suffix used to qualify Windows-specific pricing keys.
+const azureWindowsOS = "windows"
+
+// isWindowsNode reports whether the node labels identify a Windows node. It
+// centralizes the OS detection shared by azureKey.Features and NodePricing.
+func isWindowsNode(labels map[string]string) bool {
+	osLabel, ok := util.GetOperatingSystem(labels)
+	return ok && strings.ToLower(osLabel) == azureWindowsOS
+}
+
 type azureKey struct {
 type azureKey struct {
 	Labels        map[string]string
 	Labels        map[string]string
 	GPULabel      string
 	GPULabel      string
@@ -455,6 +495,9 @@ func (k *azureKey) Features() string {
 	region := strings.ToLower(r)
 	region := strings.ToLower(r)
 	instance, _ := util.GetInstanceType(k.Labels)
 	instance, _ := util.GetInstanceType(k.Labels)
 	usageType := "ondemand"
 	usageType := "ondemand"
+	if isWindowsNode(k.Labels) {
+		return fmt.Sprintf("%s,%s,%s,%s", region, instance, usageType, azureWindowsOS)
+	}
 	return fmt.Sprintf("%s,%s,%s", region, instance, usageType)
 	return fmt.Sprintf("%s,%s,%s", region, instance, usageType)
 }
 }
 
 
@@ -996,10 +1039,7 @@ func convertMeterToPricings(info commerce.MeterInfo, regions map[string]string,
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	if strings.Contains(meterSubCategory, "Windows") {
-		// This meter doesn't correspond to any pricings.
-		return nil, nil
-	}
+	isWindowsMeter := strings.Contains(meterSubCategory, "Windows")
 
 
 	if strings.Contains(meterSubCategory, "Cloud Services") || strings.Contains(meterSubCategory, "CloudServices") {
 	if strings.Contains(meterSubCategory, "Cloud Services") || strings.Contains(meterSubCategory, "CloudServices") {
 		// This meter doesn't correspond to any pricings.
 		// This meter doesn't correspond to any pricings.
@@ -1083,8 +1123,10 @@ func convertMeterToPricings(info commerce.MeterInfo, regions map[string]string,
 	priceStr := fmt.Sprintf("%f", priceInUsd)
 	priceStr := fmt.Sprintf("%f", priceInUsd)
 	results := make(map[string]*AzurePricing)
 	results := make(map[string]*AzurePricing)
 	for _, instanceType := range instanceTypes {
 	for _, instanceType := range instanceTypes {
-
 		key := fmt.Sprintf("%s,%s,%s", region, instanceType, usageType)
 		key := fmt.Sprintf("%s,%s,%s", region, instanceType, usageType)
+		if isWindowsMeter {
+			key = fmt.Sprintf("%s,%s,%s,%s", region, instanceType, usageType, azureWindowsOS)
+		}
 		pricing := &AzurePricing{
 		pricing := &AzurePricing{
 			Node: &models.Node{
 			Node: &models.Node{
 				Cost:         priceStr,
 				Cost:         priceStr,
@@ -1169,12 +1211,18 @@ func (az *Azure) NodePricing(key models.Key) (*models.Node, models.PricingMetada
 	slv, ok := azKey.Labels[config.SpotLabel]
 	slv, ok := azKey.Labels[config.SpotLabel]
 	isSpot := ok && slv == config.SpotLabelValue && config.SpotLabel != "" && config.SpotLabelValue != ""
 	isSpot := ok && slv == config.SpotLabelValue && config.SpotLabel != "" && config.SpotLabelValue != ""
 
 
+	isWindows := isWindowsNode(azKey.Labels)
+
 	features := strings.Split(azKey.Features(), ",")
 	features := strings.Split(azKey.Features(), ",")
 	region := features[0]
 	region := features[0]
 	instance := features[1]
 	instance := features[1]
 	var featureString string
 	var featureString string
 	if isSpot {
 	if isSpot {
-		featureString = fmt.Sprintf("%s,%s,spot", region, instance)
+		if isWindows {
+			featureString = fmt.Sprintf("%s,%s,spot,%s", region, instance, azureWindowsOS)
+		} else {
+			featureString = fmt.Sprintf("%s,%s,spot", region, instance)
+		}
 	} else {
 	} else {
 		featureString = azKey.Features()
 		featureString = azKey.Features()
 	}
 	}
@@ -1191,7 +1239,7 @@ func (az *Azure) NodePricing(key models.Key) (*models.Node, models.PricingMetada
 		}
 		}
 	}
 	}
 
 
-	cost, err := getRetailPrice(region, instance, config.CurrencyCode, isSpot)
+	cost, err := getRetailPrice(region, instance, config.CurrencyCode, isSpot, isWindows)
 
 
 	if err != nil {
 	if err != nil {
 		log.DedupedWarningf(5, "failed to retrieve retail pricing: %s", err)
 		log.DedupedWarningf(5, "failed to retrieve retail pricing: %s", err)

+ 209 - 13
pkg/cloud/azure/provider_test.go

@@ -69,7 +69,13 @@ func TestConvertMeterToPricings(t *testing.T) {
 		info := meterInfo("Virtual Machines", "D2 Series Windows", "D2s v3", "AU Southeast", 0.3)
 		info := meterInfo("Virtual Machines", "D2 Series Windows", "D2s v3", "AU Southeast", 0.3)
 		results, err := convertMeterToPricings(info, regions, baseCPUPrice)
 		results, err := convertMeterToPricings(info, regions, baseCPUPrice)
 		require.NoError(t, err)
 		require.NoError(t, err)
-		require.Nil(t, results)
+		key := "australiasoutheast,Standard_D2s_v3,ondemand,windows"
+		pricing, ok := results[key]
+		require.Truef(t, ok, "expected a pricing entry under key %q", key)
+		require.NotNil(t, pricing.Node)
+		require.Equal(t, "ondemand", pricing.Node.UsageType)
+		require.Equal(t, "0.300000", pricing.Node.Cost)
+		require.Equal(t, baseCPUPrice, pricing.Node.BaseCPUPrice)
 	})
 	})
 
 
 	t.Run("storage", func(t *testing.T) {
 	t.Run("storage", func(t *testing.T) {
@@ -102,6 +108,86 @@ func TestConvertMeterToPricings(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestSelectRetailPrice(t *testing.T) {
+	cases := []struct {
+		name               string
+		linuxRetailPrice   string
+		windowsRetailPrice string
+		spotPrice          string
+		windowsSpotPrice   string
+		spot               bool
+		isWindows          bool
+		expected           string
+		expectErr          bool
+	}{
+		{
+			name:               "windows retail prefers windows price",
+			linuxRetailPrice:   "1.000000",
+			windowsRetailPrice: "2.000000",
+			isWindows:          true,
+			expected:           "2.000000",
+		},
+		{
+			name:             "windows retail falls back to linux when windows missing",
+			linuxRetailPrice: "1.000000",
+			isWindows:        true,
+			expected:         "1.000000",
+		},
+		{
+			name:             "linux retail uses linux price",
+			linuxRetailPrice: "1.000000",
+			isWindows:        false,
+			expected:         "1.000000",
+		},
+		{
+			name:             "windows spot prefers windows spot price",
+			spotPrice:        "0.500000",
+			windowsSpotPrice: "0.900000",
+			spot:             true,
+			isWindows:        true,
+			expected:         "0.900000",
+		},
+		{
+			name:      "windows spot falls back to linux spot when windows missing",
+			spotPrice: "0.500000",
+			spot:      true,
+			isWindows: true,
+			expected:  "0.500000",
+		},
+		{
+			name:      "linux spot uses linux spot price",
+			spotPrice: "0.500000",
+			spot:      true,
+			isWindows: false,
+			expected:  "0.500000",
+		},
+		{
+			name:               "spot windows with no spot price falls back to retail",
+			windowsRetailPrice: "2.000000",
+			spot:               true,
+			isWindows:          true,
+			expected:           "2.000000",
+		},
+		{
+			name:      "no price available returns error",
+			isWindows: true,
+			expectErr: true,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := selectRetailPrice("eastus", "Standard_D2s_v3", tc.linuxRetailPrice, tc.windowsRetailPrice, tc.spotPrice, tc.windowsSpotPrice, tc.spot, tc.isWindows)
+			if tc.expectErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			require.Equal(t, tc.expected, got)
+		})
+	}
+}
+
 func TestAzure_findCostForDisk(t *testing.T) {
 func TestAzure_findCostForDisk(t *testing.T) {
 	var loc string = "location"
 	var loc string = "location"
 	var size int32 = 1
 	var size int32 = 1
@@ -390,14 +476,76 @@ func Test_buildAzureRetailPricesURL(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestAzureKeyFeaturesOS(t *testing.T) {
+	tests := []struct {
+		name     string
+		labels   map[string]string
+		expected string
+	}{
+		{
+			name: "windows node via kubernetes.io/os",
+			labels: map[string]string{
+				"kubernetes.io/os":                 "windows",
+				"node.kubernetes.io/instance-type": "Standard_D4s_v3",
+				"topology.kubernetes.io/region":    "eastus",
+			},
+			expected: "eastus,Standard_D4s_v3,ondemand,windows",
+		},
+		{
+			name: "windows node via beta.kubernetes.io/os",
+			labels: map[string]string{
+				"beta.kubernetes.io/os":            "windows",
+				"node.kubernetes.io/instance-type": "Standard_D4s_v3",
+				"topology.kubernetes.io/region":    "eastus",
+			},
+			expected: "eastus,Standard_D4s_v3,ondemand,windows",
+		},
+		{
+			name: "linux node",
+			labels: map[string]string{
+				"kubernetes.io/os":                 "linux",
+				"node.kubernetes.io/instance-type": "Standard_D4s_v3",
+				"topology.kubernetes.io/region":    "eastus",
+			},
+			expected: "eastus,Standard_D4s_v3,ondemand",
+		},
+		{
+			name: "no OS label defaults to linux key",
+			labels: map[string]string{
+				"node.kubernetes.io/instance-type": "Standard_D4s_v3",
+				"topology.kubernetes.io/region":    "eastus",
+			},
+			expected: "eastus,Standard_D4s_v3,ondemand",
+		},
+		{
+			name: "windows case-insensitive",
+			labels: map[string]string{
+				"kubernetes.io/os":                 "Windows",
+				"node.kubernetes.io/instance-type": "Standard_D4s_v3",
+				"topology.kubernetes.io/region":    "eastus",
+			},
+			expected: "eastus,Standard_D4s_v3,ondemand,windows",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			key := &azureKey{Labels: tc.labels}
+			require.Equal(t, tc.expected, key.Features())
+		})
+	}
+}
+
 func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 	testCases := []struct {
 	testCases := []struct {
-		name             string
-		jsonResponse     string
-		expectedRetail   string
-		expectedSpot     string
-		expectedError    bool
-		expectedErrorMsg string
+		name                  string
+		jsonResponse          string
+		expectedRetail        string
+		expectedWindowsRetail string
+		expectedSpot          string
+		expectedWindowsSpot   string
+		expectedError         bool
+		expectedErrorMsg      string
 	}{
 	}{
 		{
 		{
 			name: "valid response with retail and spot prices",
 			name: "valid response with retail and spot prices",
@@ -503,7 +651,7 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 			expectedError:  false,
 			expectedError:  false,
 		},
 		},
 		{
 		{
-			name: "filters out Windows instances",
+			name: "returns separate Windows and Linux prices",
 			jsonResponse: `{
 			jsonResponse: `{
 				"BillingCurrency": "USD",
 				"BillingCurrency": "USD",
 				"CustomerEntityId": "Default",
 				"CustomerEntityId": "Default",
@@ -528,9 +676,35 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 				],
 				],
 				"Count": 2
 				"Count": 2
 			}`,
 			}`,
-			expectedRetail: "0.192000",
-			expectedSpot:   "",
-			expectedError:  false,
+			expectedRetail:        "0.192000",
+			expectedWindowsRetail: "0.500000",
+			expectedSpot:          "",
+			expectedWindowsSpot:   "",
+			expectedError:         false,
+		},
+		{
+			name: "windows spot price available",
+			jsonResponse: `{
+				"BillingCurrency": "USD",
+				"CustomerEntityId": "Default",
+				"CustomerEntityType": "Retail",
+				"Items": [
+					{
+						"currencyCode": "USD",
+						"retailPrice": 0.12,
+						"armRegionName": "eastus",
+						"productName": "Virtual Machines Dsv3 Series Windows",
+						"skuName": "D4s v3 Spot",
+						"armSkuName": "Standard_D4s_v3"
+					}
+				],
+				"Count": 1
+			}`,
+			expectedRetail:        "",
+			expectedWindowsRetail: "",
+			expectedSpot:          "",
+			expectedWindowsSpot:   "0.120000",
+			expectedError:         false,
 		},
 		},
 		{
 		{
 			name: "filters out low priority instances",
 			name: "filters out low priority instances",
@@ -600,7 +774,7 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 				Body:       io.NopCloser(bytes.NewBufferString(tc.jsonResponse)),
 				Body:       io.NopCloser(bytes.NewBufferString(tc.jsonResponse)),
 			}
 			}
 
 
-			retailPrice, spotPrice, err := extractAzureVMRetailAndSpotPrices(resp)
+			linuxRetail, windowsRetail, spotPrice, windowsSpotPrice, err := extractAzureVMRetailAndSpotPrices(resp)
 
 
 			if tc.expectedError {
 			if tc.expectedError {
 				require.Error(t, err)
 				require.Error(t, err)
@@ -609,9 +783,31 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 				}
 				}
 			} else {
 			} else {
 				require.NoError(t, err)
 				require.NoError(t, err)
-				require.Equal(t, tc.expectedRetail, retailPrice, "Retail price mismatch")
+				require.Equal(t, tc.expectedRetail, linuxRetail, "Linux retail price mismatch")
+				require.Equal(t, tc.expectedWindowsRetail, windowsRetail, "Windows retail price mismatch")
 				require.Equal(t, tc.expectedSpot, spotPrice, "Spot price mismatch")
 				require.Equal(t, tc.expectedSpot, spotPrice, "Spot price mismatch")
+				require.Equal(t, tc.expectedWindowsSpot, windowsSpotPrice, "Windows spot price mismatch")
 			}
 			}
 		})
 		})
 	}
 	}
 }
 }
+
+// failingReader is an io.Reader that always errors, used to exercise the
+// response body read-failure path in extractAzureVMRetailAndSpotPrices.
+type failingReader struct{}
+
+func (failingReader) Read(_ []byte) (int, error) {
+	return 0, fmt.Errorf("simulated read failure")
+}
+
+func Test_extractAzureVMRetailAndSpotPrices_bodyReadError(t *testing.T) {
+	resp := &http.Response{
+		StatusCode: 200,
+		Body:       io.NopCloser(failingReader{}),
+	}
+
+	_, _, _, _, err := extractAzureVMRetailAndSpotPrices(resp)
+
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "error getting response")
+}