|
|
@@ -5,6 +5,7 @@ import (
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "net/http/httptest"
|
|
|
"testing"
|
|
|
|
|
|
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute"
|
|
|
@@ -553,3 +554,93 @@ func Test_extractAzureVMRetailAndSpotPrices(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func Test_getRetailPrice_NonSuccessStatusCodes(t *testing.T) {
|
|
|
+ testCases := []struct {
|
|
|
+ name string
|
|
|
+ statusCode int
|
|
|
+ wantErr bool
|
|
|
+ errMsg string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "returns error on 400 Bad Request",
|
|
|
+ statusCode: http.StatusBadRequest,
|
|
|
+ wantErr: true,
|
|
|
+ errMsg: "retail price responded with error status code 400",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "returns error on 404 Not Found",
|
|
|
+ statusCode: http.StatusNotFound,
|
|
|
+ wantErr: true,
|
|
|
+ errMsg: "retail price responded with error status code 404",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "returns error on 500 Internal Server Error",
|
|
|
+ statusCode: http.StatusInternalServerError,
|
|
|
+ wantErr: true,
|
|
|
+ errMsg: "retail price responded with error status code 500",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "returns error on 403 Forbidden",
|
|
|
+ statusCode: http.StatusForbidden,
|
|
|
+ wantErr: true,
|
|
|
+ errMsg: "retail price responded with error status code 403",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "returns error on 503 Service Unavailable",
|
|
|
+ statusCode: http.StatusServiceUnavailable,
|
|
|
+ wantErr: true,
|
|
|
+ errMsg: "retail price responded with error status code 503",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "succeeds on 200 OK",
|
|
|
+ statusCode: http.StatusOK,
|
|
|
+ wantErr: false,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range testCases {
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ // Create a test server that returns the specified status code
|
|
|
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ w.Header().Set("Content-Type", "application/json")
|
|
|
+ w.WriteHeader(tc.statusCode)
|
|
|
+ // Write a valid response body for 200 OK cases
|
|
|
+ if tc.statusCode == http.StatusOK {
|
|
|
+ fmt.Fprint(w, `{
|
|
|
+ "Items": [
|
|
|
+ {
|
|
|
+ "currencyCode": "USD",
|
|
|
+ "retailPrice": 0.096,
|
|
|
+ "armRegionName": "eastus",
|
|
|
+ "productName": "Virtual Machines Dsv3 Series",
|
|
|
+ "skuName": "D2s v3",
|
|
|
+ "armSkuName": "Standard_D2s_v3"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "Count": 1
|
|
|
+ }`)
|
|
|
+ } else {
|
|
|
+ fmt.Fprint(w, `{"error": "test error"}`)
|
|
|
+ }
|
|
|
+ }))
|
|
|
+ defer server.Close()
|
|
|
+
|
|
|
+ // Override the base URL to point to our test server
|
|
|
+ originalURL := azureRetailPricesBaseURL
|
|
|
+ azureRetailPricesBaseURL = server.URL
|
|
|
+ defer func() { azureRetailPricesBaseURL = originalURL }()
|
|
|
+
|
|
|
+ result, err := getRetailPrice("eastus", "Standard_D2s_v3", "USD", false)
|
|
|
+
|
|
|
+ if tc.wantErr {
|
|
|
+ require.Error(t, err)
|
|
|
+ require.Contains(t, err.Error(), tc.errMsg)
|
|
|
+ require.Empty(t, result)
|
|
|
+ } else {
|
|
|
+ require.NoError(t, err)
|
|
|
+ require.NotEmpty(t, result)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|