Просмотр исходного кода

Support GKE Workload Identity Federation to Provider Pricing Data (#3853)

Signed-off-by: Yannik Daellenbach <git@daellenbach.org>
Yannik Dällenbach 4 дней назад
Родитель
Сommit
6f0b923182
3 измененных файлов с 107 добавлено и 15 удалено
  1. 51 8
      pkg/cloud/gcp/provider.go
  2. 56 3
      pkg/cloud/gcp/provider_test.go
  3. 0 4
      pkg/cloud/provider/provider.go

+ 51 - 8
pkg/cloud/gcp/provider.go

@@ -6,6 +6,7 @@ import (
 	"io"
 	"math"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
 	"regexp"
@@ -37,7 +38,8 @@ import (
 
 const GKE_GPU_TAG = "cloud.google.com/gke-accelerator"
 const BigqueryUpdateType = "bigqueryupdate"
-const BillingAPIURLFmt = "https://cloudbilling.googleapis.com/v1/services/6F81-5844-456A/skus?key=%s&currencyCode=%s"
+const BillingAPIURL = "https://cloudbilling.googleapis.com/v1/services/6F81-5844-456A/skus"
+const GCPCloudOAuthScope = "https://www.googleapis.com/auth/cloud-platform"
 
 const (
 	GCPHourlyPublicIPCost = 0.01
@@ -955,8 +957,38 @@ func (gcp *GCP) parsePage(r io.Reader, inputKeys map[string]models.Key, pvKeys m
 	return gcpPricingList, nextPageToken, nil
 }
 
-func (gcp *GCP) getBillingAPIURL(apiKey, currencyCode string) string {
-	return fmt.Sprintf(BillingAPIURLFmt, apiKey, currencyCode)
+func (gcp *GCP) buildBillingAPIURL(apiKey, currencyCode string) *url.URL {
+	url, err := url.Parse(BillingAPIURL)
+	if err != nil {
+		panic("BillingAPIURL must be a valid URL")
+	}
+
+	query := url.Query()
+	query.Add("currencyCode", currencyCode)
+
+	if apiKey != "" {
+		query.Add("key", apiKey)
+	}
+
+	url.RawQuery = query.Encode()
+
+	return url
+}
+
+func (gcp *GCP) getBillingAPIClientAndURL(apiKey, currencyCode string) (*http.Client, string, error) {
+	url := gcp.buildBillingAPIURL(apiKey, currencyCode)
+
+	if apiKey != "" {
+		return http.DefaultClient, url.String(), nil
+	}
+
+	googleHttpClient, err := google.DefaultClient(context.TODO(), GCPCloudOAuthScope)
+	if err != nil {
+		log.Errorf("GCP Billing API: Workload Identity detected but failed to create authenticated client: %v", err)
+		return nil, "", err
+	}
+
+	return googleHttpClient, url.String(), nil
 }
 
 func (gcp *GCP) parsePages(inputKeys map[string]models.Key, pvKeys map[string]models.PVKey) (map[string]*GCPPricing, error) {
@@ -966,7 +998,10 @@ func (gcp *GCP) parsePages(inputKeys map[string]models.Key, pvKeys map[string]mo
 		return nil, err
 	}
 
-	url := gcp.getBillingAPIURL(gcp.APIKey, c.CurrencyCode)
+	httpClient, url, err := gcp.getBillingAPIClientAndURL(gcp.APIKey, c.CurrencyCode)
+	if err != nil {
+		return nil, err
+	}
 
 	var parsePagesHelper func(string) error
 	parsePagesHelper = func(pageToken string) error {
@@ -975,7 +1010,7 @@ func (gcp *GCP) parsePages(inputKeys map[string]models.Key, pvKeys map[string]mo
 		} else if pageToken != "" {
 			url = url + "&pageToken=" + pageToken
 		}
-		resp, err := http.Get(url)
+		resp, err := httpClient.Get(url)
 		if err != nil {
 			return err
 		}
@@ -1339,7 +1374,15 @@ func (gcp *GCP) getReservedInstances() ([]*GCPReservedInstance, error) {
 		return nil, err
 	}
 
-	commitments, err := computeService.RegionCommitments.AggregatedList(gcp.ProjectID).Do()
+	projID := gcp.ProjectID
+	if projID == "" {
+		projID, err = gcp.MetadataClient.ProjectID()
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	commitments, err := computeService.RegionCommitments.AggregatedList(projID).Do()
 	if err != nil {
 		return nil, err
 	}
@@ -1568,7 +1611,7 @@ func (gcp *GCP) NodePricing(key models.Key) (*models.Node, models.PricingMetadat
 		log.Debugf("Returning pricing for node %s: %+v from SKU %s", key, n.Node, n.Name)
 
 		// Add pricing URL, but redact the key (hence, "***"")
-		meta.Source = fmt.Sprintf("Downloaded pricing from %s", gcp.getBillingAPIURL("***", c.CurrencyCode))
+		meta.Source = fmt.Sprintf("Downloaded pricing from %s", gcp.buildBillingAPIURL("***", c.CurrencyCode))
 
 		n.Node.BaseCPUPrice = gcp.BaseCPUPrice
 
@@ -1588,7 +1631,7 @@ func (gcp *GCP) NodePricing(key models.Key) (*models.Node, models.PricingMetadat
 			log.Debugf("Returning pricing for node %s: %+v from SKU %s", key, n.Node, n.Name)
 
 			// Add pricing URL, but redact the key (hence, "***"")
-			meta.Source = fmt.Sprintf("Downloaded pricing from %s", gcp.getBillingAPIURL("***", c.CurrencyCode))
+			meta.Source = fmt.Sprintf("Downloaded pricing from %s", gcp.buildBillingAPIURL("***", c.CurrencyCode))
 
 			n.Node.BaseCPUPrice = gcp.BaseCPUPrice
 

+ 56 - 3
pkg/cloud/gcp/provider_test.go

@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"net/http"
+	"net/url"
 	"os"
 	"reflect"
 	"strings"
@@ -707,11 +709,62 @@ func TestGCP_findCostForDisk(t *testing.T) {
 }
 
 func TestGCP_getBillingAPIURL(t *testing.T) {
+	tests := []struct {
+		name           string
+		apiKey         string
+		currency       string
+		expectedParams map[string]string
+		absentParams   []string
+	}{
+		{
+			name:           "with API key and currency",
+			apiKey:         "test-key",
+			currency:       "USD",
+			expectedParams: map[string]string{"key": "test-key", "currencyCode": "USD"},
+		},
+		{
+			name:           "empty API key omits key param",
+			apiKey:         "",
+			currency:       "USD",
+			expectedParams: map[string]string{"currencyCode": "USD"},
+			absentParams:   []string{"key"},
+		},
+		{
+			name:           "non-USD currency",
+			apiKey:         "my-key",
+			currency:       "EUR",
+			expectedParams: map[string]string{"key": "my-key", "currencyCode": "EUR"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			gcp := &GCP{}
+			query := gcp.buildBillingAPIURL(tt.apiKey, tt.currency).Query()
+
+			for param, expected := range tt.expectedParams {
+				assert.Equal(t, expected, query.Get(param), "query param %q", param)
+			}
+			for _, param := range tt.absentParams {
+				assert.False(t, query.Has(param), "query param %q should be absent", param)
+			}
+		})
+	}
+}
+
+func TestGCP_getBillingAPIClientAndURL(t *testing.T) {
 	gcp := &GCP{}
 
-	url := gcp.getBillingAPIURL("test-key", "USD")
-	expected := "https://cloudbilling.googleapis.com/v1/services/6F81-5844-456A/skus?key=test-key&currencyCode=USD"
-	assert.Equal(t, expected, url)
+	client, rawURL, err := gcp.getBillingAPIClientAndURL("test-key", "USD")
+
+	assert.NoError(t, err)
+	assert.Equal(t, http.DefaultClient, client)
+
+	parsedURL, err := url.Parse(rawURL)
+	assert.NoError(t, err)
+	query := parsedURL.Query()
+	assert.Equal(t, "test-key", query.Get("key"))
+	assert.Equal(t, "USD", query.Get("currencyCode"))
 }
 
 func TestGCP_GpuPricing(t *testing.T) {

+ 0 - 4
pkg/cloud/provider/provider.go

@@ -2,7 +2,6 @@ package provider
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"net"
 	"net/http"
@@ -143,9 +142,6 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string, config *config.
 		}, nil
 	case opencost.GCPProvider:
 		log.Info("Found ProviderID starting with \"gce\", using GCP Provider")
-		if apiKey == "" {
-			return nil, errors.New("Supply a GCP Key to start getting data")
-		}
 		return &gcp.GCP{
 			Clientset:        cache,
 			APIKey:           apiKey,