2
0
Эх сурвалжийг харах

Move pricesheetDownloader to azurepricesheet.Downloader

This avoids polluting the cloud package.

Pass a function to convert meter info into pricings in as config for
the downloader so we can share the convertMeterToPricings function
between the rate card and price sheet paths.

Similarly, since AzurePricing is defined in the cloud package, make
Downloader generic over that type. I'm not sure about this, it seems a
little clever - is there a sensible place to move AzurePricing to so
it can be imported in both pkg/cloud and pkg/cloud/azurepricesheet?

Signed-off-by: Christian Muirhead <christian.muirhead@microsoft.com>
Christian Muirhead 3 жил өмнө
parent
commit
3f707c380c

+ 252 - 0
pkg/cloud/azurepricesheet/downloader.go

@@ -0,0 +1,252 @@
+package azurepricesheet
+
+import (
+	"bufio"
+	"context"
+	"encoding/csv"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"sort"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/profiles/2020-09-01/commerce/mgmt/commerce"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
+	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
+
+	"github.com/opencost/opencost/pkg/log"
+)
+
+type Downloader[T any] struct {
+	TenantID         string
+	ClientID         string
+	ClientSecret     string
+	BillingAccount   string
+	OfferID          string
+	ConvertMeterInfo func(info commerce.MeterInfo) (map[string]*T, error)
+}
+
+func (d *Downloader[T]) Run(ctx context.Context) (map[string]*T, error) {
+	log.Infof("requesting pricesheet download link")
+	url, err := d.getPricesheetDownloadURL(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("getting download URL: %w", err)
+	}
+	log.Infof("downloading pricesheet from %q", url)
+	data, err := d.saveData(ctx, url, "pricesheet")
+	if err != nil {
+		return nil, fmt.Errorf("saving pricesheet from %q: %w", url, err)
+	}
+	defer data.Close()
+
+	prices, err := d.readPricesheet(ctx, data)
+	if err != nil {
+		return nil, fmt.Errorf("reading pricesheet: %w", err)
+	}
+	log.Infof("loaded %d pricings from pricesheet", len(prices))
+	return prices, nil
+}
+
+func (d *Downloader[T]) getPricesheetDownloadURL(ctx context.Context) (string, error) {
+	cred, err := azidentity.NewClientSecretCredential(d.TenantID, d.ClientID, d.ClientSecret, nil)
+	if err != nil {
+		return "", fmt.Errorf("creating credential: %w", err)
+	}
+	client, err := NewClient(d.BillingAccount, cred, nil)
+	if err != nil {
+		return "", fmt.Errorf("creating pricesheet client: %w", err)
+	}
+	poller, err := client.BeginDownloadByBillingPeriod(ctx, currentBillingPeriod())
+	if err != nil {
+		return "", fmt.Errorf("beginning pricesheet download: %w", err)
+	}
+	resp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
+		Frequency: 30 * time.Second,
+	})
+	if err != nil {
+		return "", fmt.Errorf("polling for pricesheet: %w", err)
+	}
+	return resp.Properties.DownloadURL, nil
+}
+
+func (d Downloader[T]) saveData(ctx context.Context, url, tempName string) (io.ReadCloser, error) {
+	// Download file from URL in response.
+	out, err := os.CreateTemp("", tempName)
+	if err != nil {
+		return nil, fmt.Errorf("creating %s temp file: %w", tempName, err)
+	}
+
+	resp, err := http.Get(url)
+	if err != nil {
+		return nil, fmt.Errorf("downloading: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected HTTP status %d", resp.StatusCode)
+	}
+
+	if _, err := io.Copy(out, resp.Body); err != nil {
+		return nil, fmt.Errorf("reading response: %w", err)
+	}
+
+	_, err = out.Seek(0, io.SeekStart)
+	if err != nil {
+		return nil, fmt.Errorf("seeking to start of file: %w", err)
+	}
+
+	return out, nil
+}
+
+func (d *Downloader[T]) readPricesheet(ctx context.Context, data io.Reader) (map[string]*T, error) {
+	// Avoid double-buffering.
+	buf, ok := (data).(*bufio.Reader)
+	if !ok {
+		buf = bufio.NewReader(data)
+	}
+
+	// The CSV file starts with two lines before the header without
+	// commas (so different numbers of fields as far as the CSV parser
+	// is concerned). Skip them before making the CSV reader so we
+	// still get the benefit of the row length checks after the
+	// header.
+	for i := 0; i < 2; i++ {
+		_, err := buf.ReadBytes('\n')
+		if err != nil {
+			return nil, fmt.Errorf("skipping preamble line %d: %w", i, err)
+		}
+	}
+	reader := csv.NewReader(buf)
+	reader.ReuseRecord = true
+
+	header, err := reader.Read()
+	if err != nil {
+		return nil, fmt.Errorf("reading header: %w", err)
+	}
+	if err := checkPricesheetHeader(header); err != nil {
+		return nil, err
+	}
+
+	units := make(map[string]bool)
+
+	results := make(map[string]*T)
+	lines := 2
+	for {
+		row, err := reader.Read()
+		if err == io.EOF {
+			break
+		}
+		lines++
+		if err != nil {
+			return nil, fmt.Errorf("reading line %d: %w", lines, err)
+		}
+
+		// Skip savings plan - we should be reporting based on the
+		// consumption price because we don't know whether the user is
+		// using a savings plan or over their threshold.
+		if row[pricesheetPriceType] == "Savings Plan" || row[pricesheetOfferID] != d.OfferID {
+			continue
+		}
+
+		// TODO: Creating a meter info for each record will cause a
+		// lot of GC churn - is it worth reusing one meter info instead?
+		meterInfo, err := makeMeterInfo(row)
+		if err != nil {
+			log.Warnf("making meter info (line %d): %v", lines, err)
+			continue
+		}
+
+		pricings, err := d.ConvertMeterInfo(meterInfo)
+		if err != nil {
+			log.Warnf("converting meter to pricings (line %d): %v", lines, err)
+			continue
+		}
+
+		if pricings != nil {
+			units[*meterInfo.Unit] = true
+		}
+
+		// TODO: add pricings for AzureFileStandardStorageClass
+
+		for key, pricing := range pricings {
+			results[key] = pricing
+		}
+	}
+
+	if len(results) == 0 {
+		return nil, fmt.Errorf("no matching pricing from pricesheet")
+	}
+
+	// This is temporary, gathering info while adding unit normalisation.
+	allUnits := make([]string, 0, len(units))
+	for unit := range units {
+		allUnits = append(allUnits, unit)
+	}
+	sort.Strings(allUnits)
+	log.Infof("all units in pricesheet: %s", strings.Join(allUnits, ", "))
+
+	return results, nil
+}
+
+func checkPricesheetHeader(header []string) error {
+	for name, col := range pricesheetCols {
+		if !strings.EqualFold(header[col], name) {
+			return fmt.Errorf("unexpected header %q, expected %q", header[col], name)
+		}
+	}
+	return nil
+}
+
+func makeMeterInfo(row []string) (commerce.MeterInfo, error) {
+	price, err := strconv.ParseFloat(row[pricesheetUnitPrice], 64)
+	if err != nil {
+		return commerce.MeterInfo{}, fmt.Errorf("parsing unit price: %w", err)
+	}
+	// TODO: normalize units - some meters are for 1 hour or 1
+	// GB/Month, others are for 10 or 100.
+	return commerce.MeterInfo{
+		MeterName:        ptr(row[pricesheetMeterName]),
+		MeterCategory:    ptr(row[pricesheetMeterCategory]),
+		MeterSubCategory: ptr(row[pricesheetMeterSubCategory]),
+		Unit:             ptr(row[pricesheetUnit]),
+		MeterRegion:      ptr(row[pricesheetMeterRegion]),
+		MeterRates:       map[string]*float64{"0": &price},
+	}, nil
+}
+
+var pricesheetCols = map[string]int{
+	"Meter ID":           pricesheetMeterID,
+	"Meter name":         pricesheetMeterName,
+	"Meter category":     pricesheetMeterCategory,
+	"Meter sub-category": pricesheetMeterSubCategory,
+	"Meter region":       pricesheetMeterRegion,
+	"Unit":               pricesheetUnit,
+	"Unit price":         pricesheetUnitPrice,
+	"Currency code":      pricesheetCurrencyCode,
+	"Offer Id":           pricesheetOfferID,
+	"Price type":         pricesheetPriceType,
+}
+
+const (
+	pricesheetMeterID          = 0
+	pricesheetMeterName        = 1
+	pricesheetMeterCategory    = 2
+	pricesheetMeterSubCategory = 3
+	pricesheetMeterRegion      = 4
+	pricesheetUnit             = 5
+	pricesheetUnitPrice        = 8
+	pricesheetCurrencyCode     = 9
+	pricesheetOfferID          = 11
+	pricesheetPriceType        = 13
+)
+
+func currentBillingPeriod() string {
+	return time.Now().Format("200601")
+}
+
+func ptr[T any](v T) *T {
+	return &v
+}

+ 11 - 247
pkg/cloud/azureprovider.go

@@ -1,23 +1,18 @@
 package cloud
 
 import (
-	"bufio"
 	"context"
-	"encoding/csv"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
 	"os"
 	"regexp"
-	"sort"
 	"strconv"
 	"strings"
 	"sync"
 	"time"
 
-	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
-	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
 	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute"
 	"github.com/Azure/azure-sdk-for-go/services/preview/commerce/mgmt/2015-06-01-preview/commerce"
 	"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2016-06-01/subscriptions"
@@ -889,20 +884,21 @@ func (az *Azure) DownloadPricingData() error {
 
 	// If we've got a billing account set, kick off downloading the custom pricing data.
 	if config.AzureBillingAccount != "" {
-		downloader := pricesheetDownloader{
-			tenantID:       config.AzureTenantID,
-			clientID:       config.AzureClientID,
-			clientSecret:   config.AzureClientSecret,
-			billingAccount: config.AzureBillingAccount,
-			offerID:        config.AzureOfferDurableID,
-			regions:        regions,
-			baseCPUPrice:   baseCPUPrice,
+		downloader := pricesheet.Downloader[AzurePricing]{
+			TenantID:       config.AzureTenantID,
+			ClientID:       config.AzureClientID,
+			ClientSecret:   config.AzureClientSecret,
+			BillingAccount: config.AzureBillingAccount,
+			OfferID:        config.AzureOfferDurableID,
+			ConvertMeterInfo: func(meterInfo commerce.MeterInfo) (map[string]*AzurePricing, error) {
+				return convertMeterToPricings(meterInfo, regions, baseCPUPrice)
+			},
 		}
 		// The price sheet can take 5 minutes to generate, so we don't
 		// want to hang onto the lock while we're waiting for it.
 		go func() {
 			ctx := context.Background()
-			allPrices, err := downloader.run(ctx)
+			allPrices, err := downloader.Run(ctx)
 
 			az.DownloadPricingDataLock.Lock()
 			defer az.DownloadPricingDataLock.Unlock()
@@ -965,7 +961,7 @@ func convertMeterToPricings(info commerce.MeterInfo, regions map[string]string,
 				key := region + "," + storageClass
 				log.Debugf("Adding PV.Key: %s, Cost: %s", key, priceStr)
 				return map[string]*AzurePricing{
-					key: &AzurePricing{
+					key: {
 						PV: &PV{
 							Cost:   priceStr,
 							Region: region,
@@ -1028,238 +1024,6 @@ func convertMeterToPricings(info commerce.MeterInfo, regions map[string]string,
 
 }
 
-type pricesheetDownloader struct {
-	tenantID       string
-	clientID       string
-	clientSecret   string
-	billingAccount string
-	offerID        string
-	regions        map[string]string
-	baseCPUPrice   string
-}
-
-func (d *pricesheetDownloader) run(ctx context.Context) (map[string]*AzurePricing, error) {
-	log.Infof("requesting pricesheet download link")
-	url, err := d.getPricesheetDownloadURL(ctx)
-	if err != nil {
-		return nil, fmt.Errorf("getting download URL: %w", err)
-	}
-	log.Infof("downloading pricesheet from %q", url)
-	data, err := d.saveData(ctx, url, "pricesheet")
-	if err != nil {
-		return nil, fmt.Errorf("saving pricesheet from %q: %w", url, err)
-	}
-	defer data.Close()
-
-	prices, err := d.readPricesheet(ctx, data)
-	if err != nil {
-		return nil, fmt.Errorf("reading pricesheet: %w", err)
-	}
-	log.Infof("loaded %d pricings from pricesheet", len(prices))
-	return prices, nil
-}
-
-func (d *pricesheetDownloader) getPricesheetDownloadURL(ctx context.Context) (string, error) {
-	cred, err := azidentity.NewClientSecretCredential(d.tenantID, d.clientID, d.clientSecret, nil)
-	if err != nil {
-		return "", fmt.Errorf("creating credential: %w", err)
-	}
-	client, err := pricesheet.NewClient(d.billingAccount, cred, nil)
-	if err != nil {
-		return "", fmt.Errorf("creating pricesheet client: %w", err)
-	}
-	poller, err := client.BeginDownloadByBillingPeriod(ctx, currentBillingPeriod())
-	if err != nil {
-		return "", fmt.Errorf("beginning pricesheet download: %w", err)
-	}
-	resp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
-		Frequency: 30 * time.Second,
-	})
-	if err != nil {
-		return "", fmt.Errorf("polling for pricesheet: %w", err)
-	}
-	return resp.Properties.DownloadURL, nil
-}
-
-func (d pricesheetDownloader) saveData(ctx context.Context, url, tempName string) (io.ReadCloser, error) {
-	// Download file from URL in response.
-	out, err := os.CreateTemp("", tempName)
-	if err != nil {
-		return nil, fmt.Errorf("creating %s temp file: %w", tempName, err)
-	}
-
-	resp, err := http.Get(url)
-	if err != nil {
-		return nil, fmt.Errorf("downloading: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("unexpected HTTP status %d", resp.StatusCode)
-	}
-
-	if _, err := io.Copy(out, resp.Body); err != nil {
-		return nil, fmt.Errorf("reading response: %w", err)
-	}
-
-	_, err = out.Seek(0, io.SeekStart)
-	if err != nil {
-		return nil, fmt.Errorf("seeking to start of file: %w", err)
-	}
-
-	return out, nil
-}
-
-func (d *pricesheetDownloader) readPricesheet(ctx context.Context, data io.Reader) (map[string]*AzurePricing, error) {
-	// Avoid double-buffering.
-	buf, ok := (data).(*bufio.Reader)
-	if !ok {
-		buf = bufio.NewReader(data)
-	}
-
-	// The CSV file starts with two lines before the header without
-	// commas (so different numbers of fields as far as the CSV parser
-	// is concerned). Skip them before making the CSV reader so we
-	// still get the benefit of the row length checks after the
-	// header.
-	for i := 0; i < 2; i++ {
-		_, err := buf.ReadBytes('\n')
-		if err != nil {
-			return nil, fmt.Errorf("skipping preamble line %d: %w", i, err)
-		}
-	}
-	reader := csv.NewReader(buf)
-	reader.ReuseRecord = true
-
-	header, err := reader.Read()
-	if err != nil {
-		return nil, fmt.Errorf("reading header: %w", err)
-	}
-	if err := checkPricesheetHeader(header); err != nil {
-		return nil, err
-	}
-
-	units := make(map[string]bool)
-
-	results := make(map[string]*AzurePricing)
-	lines := 2
-	for {
-		row, err := reader.Read()
-		if err == io.EOF {
-			break
-		}
-		lines++
-		if err != nil {
-			return nil, fmt.Errorf("reading line %d: %w", lines, err)
-		}
-
-		// Skip savings plan - we should be reporting based on the
-		// consumption price because we don't know whether the user is
-		// using a savings plan or over their threshold.
-		if row[pricesheetPriceType] == "Savings Plan" || row[pricesheetOfferID] != d.offerID {
-			continue
-		}
-
-		// TODO: Creating a meter info for each record will cause a
-		// lot of GC churn - is it worth reusing one meter info instead?
-		meterInfo, err := makeMeterInfo(row)
-		if err != nil {
-			log.Warnf("making meter info (line %d): %v", lines, err)
-			continue
-		}
-
-		pricings, err := convertMeterToPricings(meterInfo, d.regions, d.baseCPUPrice)
-		if err != nil {
-			log.Warnf("converting meter to pricings (line %d): %v", lines, err)
-			continue
-		}
-
-		if pricings != nil {
-			units[*meterInfo.Unit] = true
-		}
-
-		// TODO: add pricings for AzureFileStandardStorageClass
-
-		for key, pricing := range pricings {
-			results[key] = pricing
-		}
-	}
-
-	if len(results) == 0 {
-		return nil, fmt.Errorf("no matching pricing from pricesheet")
-	}
-
-	// This is temporary, gathering info while adding unit normalisation.
-	allUnits := make([]string, 0, len(units))
-	for unit := range units {
-		allUnits = append(allUnits, unit)
-	}
-	sort.Strings(allUnits)
-	log.Infof("all units in pricesheet: %s", strings.Join(allUnits, ", "))
-
-	return results, nil
-}
-
-func checkPricesheetHeader(header []string) error {
-	for name, col := range pricesheetCols {
-		if !strings.EqualFold(header[col], name) {
-			return fmt.Errorf("unexpected header %q, expected %q", header[col], name)
-		}
-	}
-	return nil
-}
-
-func makeMeterInfo(row []string) (commerce.MeterInfo, error) {
-	price, err := strconv.ParseFloat(row[pricesheetUnitPrice], 64)
-	if err != nil {
-		return commerce.MeterInfo{}, fmt.Errorf("parsing unit price: %w", err)
-	}
-	// TODO: normalize units - some meters are for 1 hour or 1
-	// GB/Month, others are for 10 or 100.
-	return commerce.MeterInfo{
-		MeterName:        ptr(row[pricesheetMeterName]),
-		MeterCategory:    ptr(row[pricesheetMeterCategory]),
-		MeterSubCategory: ptr(row[pricesheetMeterSubCategory]),
-		Unit:             ptr(row[pricesheetUnit]),
-		MeterRegion:      ptr(row[pricesheetMeterRegion]),
-		MeterRates:       map[string]*float64{"0": &price},
-	}, nil
-}
-
-var pricesheetCols = map[string]int{
-	"Meter ID":           pricesheetMeterID,
-	"Meter name":         pricesheetMeterName,
-	"Meter category":     pricesheetMeterCategory,
-	"Meter sub-category": pricesheetMeterSubCategory,
-	"Meter region":       pricesheetMeterRegion,
-	"Unit":               pricesheetUnit,
-	"Unit price":         pricesheetUnitPrice,
-	"Currency code":      pricesheetCurrencyCode,
-	"Offer Id":           pricesheetOfferID,
-	"Price type":         pricesheetPriceType,
-}
-
-const (
-	pricesheetMeterID          = 0
-	pricesheetMeterName        = 1
-	pricesheetMeterCategory    = 2
-	pricesheetMeterSubCategory = 3
-	pricesheetMeterRegion      = 4
-	pricesheetUnit             = 5
-	pricesheetUnitPrice        = 8
-	pricesheetCurrencyCode     = 9
-	pricesheetOfferID          = 11
-	pricesheetPriceType        = 13
-)
-
-func currentBillingPeriod() string {
-	return time.Now().Format("200601")
-}
-
-func ptr[T any](v T) *T {
-	return &v
-}
-
 // determineCloudByRegion uses region name to pick the correct Cloud Environment for the azure provider to use
 func determineCloudByRegion(region string) azure.Environment {
 	lcRegion := strings.ToLower(region)