Răsfoiți Sursa

Fix status code check in getRetailPrice and add unit tests

Fix logical operator bug (&&→||) in non-2xx status code check that made
the condition impossible to satisfy. Add defer resp.Body.Close() to
prevent resource leaks. Extract base URL into a package-level variable
to enable test server injection, and add comprehensive unit tests
covering getRetailPrice rejection of non-2xx HTTP responses.

Addresses review feedback from PR #3702.
Claude 2 luni în urmă
părinte
comite
b1624fd821
2 a modificat fișierele cu 98 adăugiri și 2 ștergeri
  1. 7 2
      pkg/cloud/azure/provider.go
  2. 91 0
      pkg/cloud/azure/provider_test.go

+ 7 - 2
pkg/cloud/azure/provider.go

@@ -239,8 +239,12 @@ func getRegions(service string, subscriptionsClient subscriptions.Client, provid
 	}
 }
 
+// azureRetailPricesBaseURL is the base URL for Azure retail prices API.
+// It is a variable so that tests can override it with a local test server.
+var azureRetailPricesBaseURL = "https://prices.azure.com/api/retail/prices"
+
 func buildAzureRetailPricesURL(region string, skuName string, currencyCode string) string {
-	pricingURL := "https://prices.azure.com/api/retail/prices?$skip=0"
+	pricingURL := azureRetailPricesBaseURL + "?$skip=0"
 
 	if currencyCode != "" {
 		pricingURL += fmt.Sprintf("&currencyCode='%s'", currencyCode)
@@ -306,8 +310,9 @@ func getRetailPrice(region string, skuName string, currencyCode string, spot boo
 	if err != nil {
 		return "", fmt.Errorf("failed to fetch retail price with URL \"%s\": %v", pricingURL, err)
 	}
+	defer resp.Body.Close()
 
-	if resp.StatusCode < 200 && resp.StatusCode > 299 {
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
 		return "", fmt.Errorf("retail price responded with error status code %d", resp.StatusCode)
 	}
 

+ 91 - 0
pkg/cloud/azure/provider_test.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"net/http/httptest"
 	"testing"
 
 	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute"
@@ -553,3 +554,93 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
 		})
 	}
 }
+
+func Test_getRetailPrice_NonSuccessStatusCodes(t *testing.T) {
+	testCases := []struct {
+		name       string
+		statusCode int
+		wantErr    bool
+		errMsg     string
+	}{
+		{
+			name:       "returns error on 400 Bad Request",
+			statusCode: http.StatusBadRequest,
+			wantErr:    true,
+			errMsg:     "retail price responded with error status code 400",
+		},
+		{
+			name:       "returns error on 404 Not Found",
+			statusCode: http.StatusNotFound,
+			wantErr:    true,
+			errMsg:     "retail price responded with error status code 404",
+		},
+		{
+			name:       "returns error on 500 Internal Server Error",
+			statusCode: http.StatusInternalServerError,
+			wantErr:    true,
+			errMsg:     "retail price responded with error status code 500",
+		},
+		{
+			name:       "returns error on 403 Forbidden",
+			statusCode: http.StatusForbidden,
+			wantErr:    true,
+			errMsg:     "retail price responded with error status code 403",
+		},
+		{
+			name:       "returns error on 503 Service Unavailable",
+			statusCode: http.StatusServiceUnavailable,
+			wantErr:    true,
+			errMsg:     "retail price responded with error status code 503",
+		},
+		{
+			name:       "succeeds on 200 OK",
+			statusCode: http.StatusOK,
+			wantErr:    false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Create a test server that returns the specified status code
+			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.Header().Set("Content-Type", "application/json")
+				w.WriteHeader(tc.statusCode)
+				// Write a valid response body for 200 OK cases
+				if tc.statusCode == http.StatusOK {
+					fmt.Fprint(w, `{
+						"Items": [
+							{
+								"currencyCode": "USD",
+								"retailPrice": 0.096,
+								"armRegionName": "eastus",
+								"productName": "Virtual Machines Dsv3 Series",
+								"skuName": "D2s v3",
+								"armSkuName": "Standard_D2s_v3"
+							}
+						],
+						"Count": 1
+					}`)
+				} else {
+					fmt.Fprint(w, `{"error": "test error"}`)
+				}
+			}))
+			defer server.Close()
+
+			// Override the base URL to point to our test server
+			originalURL := azureRetailPricesBaseURL
+			azureRetailPricesBaseURL = server.URL
+			defer func() { azureRetailPricesBaseURL = originalURL }()
+
+			result, err := getRetailPrice("eastus", "Standard_D2s_v3", "USD", false)
+
+			if tc.wantErr {
+				require.Error(t, err)
+				require.Contains(t, err.Error(), tc.errMsg)
+				require.Empty(t, result)
+			} else {
+				require.NoError(t, err)
+				require.NotEmpty(t, result)
+			}
+		})
+	}
+}