Просмотр исходного кода

add unit tests and other suggestions

Sean Holcomb 4 лет назад
Родитель
Сommit
913e6293a8

+ 2 - 2
pkg/cloud/azureprovider.go

@@ -1424,8 +1424,8 @@ func (az *Azure) Regions() []string {
 }
 
 func parseAzureSubscriptionID(id string) string {
-	// azure:///subscriptions/0bd50fdf-c923-4e1e-850c-196dd3dcc5d3/...
-	//  => 0bd50fdf-c923-4e1e-850c-196dd3dcc5d3
+	// azure:///subscriptions/0badafdf-1234-abcd-wxyz-123456789/...
+	//  => 0badafdf-1234-abcd-wxyz-123456789
 	rx := regexp.MustCompile("azure:///subscriptions/([^/]*)/*")
 	match := rx.FindStringSubmatch(id)
 	if len(match) >= 2 {

+ 36 - 0
pkg/cloud/azureprovider_test.go

@@ -0,0 +1,36 @@
+package cloud
+
+import (
+	"testing"
+)
+
+func TestParseAzureSubscriptionID(t *testing.T) {
+	cases := []struct {
+		input    string
+		expected string
+	}{
+		{
+			input:    "azure:///subscriptions/0badafdf-1234-abcd-wxyz-123456789/...",
+			expected: "0badafdf-1234-abcd-wxyz-123456789",
+		},
+		{
+			input:    "azure:/subscriptions/0badafdf-1234-abcd-wxyz-123456789/...",
+			expected: "",
+		},
+		{
+			input:    "azure:///subscriptions//",
+			expected: "",
+		},
+		{
+			input:    "",
+			expected: "",
+		},
+	}
+
+	for _, test := range cases {
+		result := parseAzureSubscriptionID(test.input)
+		if result != test.expected {
+			t.Errorf("Input: %s, Expected: %s, Actual: %s", test.input, test.expected, result)
+		}
+	}
+}

+ 3 - 3
pkg/cloud/gcpprovider.go

@@ -1529,9 +1529,9 @@ func sustainedUseDiscount(class string, defaultDiscount float64, isPreemptible b
 	return discount
 }
 
-func parseGCPProjectID (id string) string {
-	// gce://guestbook-227502/...
-	//  => guestbook-227502
+func parseGCPProjectID(id string) string {
+	// gce://guestbook-12345/...
+	//  => guestbook-12345
 	rx := regexp.MustCompile("gce://([^/]*)/*")
 	match := rx.FindStringSubmatch(id)
 	if len(match) >= 2 {

+ 31 - 0
pkg/cloud/gcpprovider_test.go

@@ -34,3 +34,34 @@ func TestParseGCPInstanceTypeLabel(t *testing.T) {
 		}
 	}
 }
+
+func TestParseGCPProjectID(t *testing.T) {
+	cases := []struct {
+		input    string
+		expected string
+	}{
+		{
+			input:    "gce://guestbook-12345/...",
+			expected: "guestbook-12345",
+		},
+		{
+			input:    "gce:/guestbook-12345/...",
+			expected: "",
+		},
+		{
+			input:    "asdfa",
+			expected: "",
+		},
+		{
+			input:    "",
+			expected: "",
+		},
+	}
+
+	for _, test := range cases {
+		result := parseGCPProjectID(test.input)
+		if result != test.expected {
+			t.Errorf("Input: %s, Expected: %s, Actual: %s", test.input, test.expected, result)
+		}
+	}
+}

+ 43 - 33
pkg/cloud/provider.go

@@ -412,16 +412,16 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string) (Provider, erro
 		return nil, fmt.Errorf("Could not locate any nodes for cluster.")
 	}
 
-	provider, configFileName, region, accountID, projectID := getClusterProperties(nodes[0])
+	cp := getClusterProperties(nodes[0])
 
-	switch provider {
+	switch cp.provider {
 	case "CSV":
 		klog.Infof("Using CSV Provider with CSV at %s", env.GetCSVPath())
 		return &CSVProvider{
 			CSVLocation: env.GetCSVPath(),
 			CustomProvider: &CustomProvider{
 				Clientset: cache,
-				Config:    NewProviderConfig(configFileName),
+				Config:    NewProviderConfig(cp.configFileName),
 			},
 		}, nil
 	case "GCP":
@@ -430,61 +430,71 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string) (Provider, erro
 			return nil, errors.New("Supply a GCP Key to start getting data")
 		}
 		return &GCP{
-			Clientset: cache,
-			APIKey:    apiKey,
-			Config:    NewProviderConfig(configFileName),
-			clusterRegion: region,
-			clusterProjectId: projectID,
+			Clientset:        cache,
+			APIKey:           apiKey,
+			Config:           NewProviderConfig(cp.configFileName),
+			clusterRegion:    cp.region,
+			clusterProjectId: cp.projectID,
 		}, nil
 	case "AWS":
 		klog.V(2).Info("Found ProviderID starting with \"aws\", using AWS Provider")
 		return &AWS{
-			Clientset: cache,
-			Config:    NewProviderConfig(configFileName),
-			clusterRegion: region,
-			clusterAccountId: accountID,
+			Clientset:        cache,
+			Config:           NewProviderConfig(cp.configFileName),
+			clusterRegion:    cp.region,
+			clusterAccountId: cp.accountID,
 		}, nil
 	case "AZURE":
 		klog.V(2).Info("Found ProviderID starting with \"azure\", using Azure Provider")
 		return &Azure{
-			Clientset: cache,
-			Config:    NewProviderConfig(configFileName),
-			clusterRegion: region,
-			clusterAccountId: accountID,
+			Clientset:        cache,
+			Config:           NewProviderConfig(cp.configFileName),
+			clusterRegion:    cp.region,
+			clusterAccountId: cp.accountID,
 		}, nil
 	default:
 		klog.V(2).Info("Unsupported provider, falling back to default")
 		return &CustomProvider{
 			Clientset: cache,
-			Config:    NewProviderConfig(configFileName),
+			Config:    NewProviderConfig(cp.configFileName),
 		}, nil
 	}
 }
 
-func getClusterProperties(node *v1.Node) (string, string, string, string, string) {
+type clusterProperties struct {
+	provider       string
+	configFileName string
+	region         string
+	accountID      string
+	projectID      string
+}
+
+func getClusterProperties(node *v1.Node) (clusterProperties) {
 	providerID := strings.ToLower(node.Spec.ProviderID)
-	provider := "DEFAULT"
-	configFileName := "default.json"
-	region := node.Labels["topology.kubernetes.io/region"]
-	accountID := ""
-	projectID := ""
+	cp := clusterProperties{
+		provider: "DEFAULT",
+		configFileName: "default.json",
+		region: node.Labels["topology.kubernetes.io/region"],
+		accountID: "",
+		projectID: "",
+	}
 	if metadata.OnGCE() {
-		provider = "GCP"
-		configFileName = "gcp.json"
-		projectID = parseGCPProjectID(providerID)
+		cp.provider = "GCP"
+		cp.configFileName = "gcp.json"
+		cp.projectID = parseGCPProjectID(providerID)
 	} else if strings.HasPrefix(providerID, "aws") {
-		provider = "AWS"
-		configFileName = "aws.json"
+		cp.provider = "AWS"
+		cp.configFileName = "aws.json"
 	} else if strings.HasPrefix(providerID, "azure") {
-		provider = "AZURE"
-		configFileName = "azure.json"
-		accountID = parseAzureSubscriptionID(providerID)
+		cp.provider = "AZURE"
+		cp.configFileName = "azure.json"
+		cp.accountID = parseAzureSubscriptionID(providerID)
 	}
 	if env.IsUseCSVProvider() {
-		provider = "CSV"
+		cp.provider = "CSV"
 	}
 
-	return provider, configFileName, region, accountID, projectID
+	return cp
 }
 
 func UpdateClusterMeta(cluster_id, cluster_name string) error {