Преглед на файлове

Merge pull request #1842 from babbageclunk/azure-pricesheet

Support for Azure consumption pricesheet download API
Ajay Tripathy преди 3 години
родител
ревизия
2d15f3cfea

+ 1 - 1
go.mod

@@ -8,7 +8,7 @@ require (
 	cloud.google.com/go/storage v1.28.1
 	github.com/Azure/azure-pipeline-go v0.2.3
 	github.com/Azure/azure-sdk-for-go v65.0.0+incompatible
-	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0
+	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.1-0.20230323231529-14c481f239ec
 	github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.2
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
 	github.com/Azure/azure-storage-blob-go v0.15.0

+ 2 - 2
go.sum

@@ -56,8 +56,8 @@ github.com/Azure/azure-pipeline-go v0.2.3 h1:7U9HBg1JFK3jHl5qmo4CTZKFTVgMwdFHMVt
 github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k=
 github.com/Azure/azure-sdk-for-go v65.0.0+incompatible h1:HzKLt3kIwMm4KeJYTdx9EbjRYTySD/t8i1Ee/W5EGXw=
 github.com/Azure/azure-sdk-for-go v65.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.1-0.20230323231529-14c481f239ec h1:S83Dzhd3VLyvN2bgFI7/Lgk1etamk3Pk8QQhn3iXt4s=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.1-0.20230323231529-14c481f239ec/go.mod h1:IoxiGSzhL1QHFXa/mlAXCD+sUaP0rxg//yn2w/JH7wg=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.2 h1:uqM+VoHjVH6zdlkLF2b6O0ZANcHoj3rO0PoQ3jglUJA=
 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.2.2/go.mod h1:twTKAa1E6hLmSDjLhaCkbTMQKc7p/rNLU40rLxGEOCI=
 github.com/Azure/azure-sdk-for-go/sdk/internal v1.2.0 h1:leh5DwKv6Ihwi+h60uHtn6UWAxBbZ0q8DwQVMzf61zw=

+ 124 - 0
pkg/cloud/azurepricesheet/client.go

@@ -0,0 +1,124 @@
+package azurepricesheet
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
+	armruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
+)
+
+const (
+	moduleName    = "armconsumption"
+	moduleVersion = "v1.0.0"
+)
+
+// At the moment the consumption pricesheet download API is not a)
+// documented or b) supported by the SDK. This is an implementation of
+// a client in the style of the Azure go SDK - once the API is
+// supported this will be removed.
+
+// PriceSheetClient contains the methods for the PriceSheet group.
+// Don't use this type directly, use NewPriceSheetClient() instead.
+type PriceSheetClient struct {
+	host             string
+	billingAccountID string
+	pl               runtime.Pipeline
+}
+
+// NewClient creates a new instance of PriceSheetClient with the specified values.
+// billingAccountId - Azure Billing Account ID.
+// credential - used to authorize requests. Usually a credential from azidentity.
+// options - pass nil to accept the default values.
+func NewClient(billingAccountID string, credential azcore.TokenCredential, options *arm.ClientOptions) (*PriceSheetClient, error) {
+	if options == nil {
+		options = &arm.ClientOptions{}
+	}
+	ep := cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
+	if c, ok := options.Cloud.Services[cloud.ResourceManager]; ok {
+		ep = c.Endpoint
+	}
+	pl, err := armruntime.NewPipeline(moduleName, moduleVersion, credential, runtime.PipelineOptions{}, options)
+	if err != nil {
+		return nil, err
+	}
+	client := &PriceSheetClient{
+		billingAccountID: billingAccountID,
+		host:             ep,
+		pl:               pl,
+	}
+	return client, nil
+}
+
+// BeginDownloadByBillingPeriod - requests a pricesheet for a specific billing period `yyyymm`.
+// Returns a Poller that will provide the download URL when the pricesheet is ready.
+// If the operation fails it returns an *azcore.ResponseError type.
+// Generated from API version 2022-06-01
+// billingPeriodName - Billing Period Name `yyyymm`.
+func (client *PriceSheetClient) BeginDownloadByBillingPeriod(ctx context.Context, billingPeriodName string) (*runtime.Poller[PriceSheetClientDownloadResponse], error) {
+	resp, err := client.downloadByBillingPeriodOperation(ctx, billingPeriodName)
+	if err != nil {
+		return nil, err
+	}
+	return runtime.NewPoller[PriceSheetClientDownloadResponse](resp, client.pl, nil)
+}
+
+type PriceSheetClientDownloadResponse struct {
+	ID         string                             `json:"id"`
+	Name       string                             `json:"name"`
+	StartTime  time.Time                          `json:"startTime"`
+	EndTime    time.Time                          `json:"endTime"`
+	Status     string                             `json:"status"`
+	Properties PriceSheetClientDownloadProperties `json:"properties"`
+}
+
+type PriceSheetClientDownloadProperties struct {
+	DownloadURL string `json:"downloadUrl"`
+	ValidTill   string `json:"validTill"`
+}
+
+func (client *PriceSheetClient) downloadByBillingPeriodOperation(ctx context.Context, billingPeriodName string) (*http.Response, error) {
+	req, err := client.downloadByBillingPeriodCreateRequest(ctx, billingPeriodName)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := client.pl.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusAccepted) {
+		return nil, runtime.NewResponseError(resp)
+	}
+	return resp, nil
+}
+
+const downloadByBillingPeriodTemplate = "/providers/Microsoft.Billing/billingAccounts/%s/billingPeriods/%s/providers/Microsoft.Consumption/pricesheets/download"
+
+// downloadByBillingPeriodCreateRequest creates the DownloadByBillingPeriod request.
+func (client *PriceSheetClient) downloadByBillingPeriodCreateRequest(ctx context.Context, billingPeriodName string) (*policy.Request, error) {
+	if client.billingAccountID == "" {
+		return nil, errors.New("parameter client.billingAccountID cannot be empty")
+	}
+	if billingPeriodName == "" {
+		return nil, errors.New("parameter billingPeriodName cannot be empty")
+	}
+	urlPath := fmt.Sprintf(downloadByBillingPeriodTemplate, url.PathEscape(client.billingAccountID), url.PathEscape(billingPeriodName))
+	req, err := runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(client.host, urlPath))
+	if err != nil {
+		return nil, err
+	}
+	reqQP := req.Raw().URL.Query()
+	reqQP.Set("api-version", "2022-06-01")
+	reqQP.Set("ln", "en")
+	req.Raw().URL.RawQuery = reqQP.Encode()
+	req.Raw().Header["Accept"] = []string{"*/*"}
+	return req, nil
+}

+ 300 - 0
pkg/cloud/azurepricesheet/downloader.go

@@ -0,0 +1,300 @@
+package azurepricesheet
+
+import (
+	"bufio"
+	"context"
+	"encoding/csv"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"sort"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/Azure/azure-sdk-for-go/profiles/2020-09-01/commerce/mgmt/commerce"
+	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
+	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
+
+	"github.com/opencost/opencost/pkg/log"
+)
+
+type Downloader[T any] struct {
+	TenantID         string
+	ClientID         string
+	ClientSecret     string
+	BillingAccount   string
+	OfferID          string
+	ConvertMeterInfo func(info commerce.MeterInfo) (map[string]*T, error)
+}
+
+func (d *Downloader[T]) GetPricing(ctx context.Context) (map[string]*T, error) {
+	log.Infof("requesting pricesheet download link")
+	url, err := d.getPricesheetDownloadURL(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("getting download URL: %w", err)
+	}
+	log.Infof("downloading pricesheet from %q", url)
+	data, err := d.saveData(ctx, url, "pricesheet")
+	if err != nil {
+		return nil, fmt.Errorf("saving pricesheet from %q: %w", url, err)
+	}
+	defer data.Close()
+
+	prices, err := d.readPricesheet(ctx, data)
+	if err != nil {
+		return nil, fmt.Errorf("reading pricesheet: %w", err)
+	}
+	log.Infof("loaded %d pricings from pricesheet", len(prices))
+	return prices, nil
+}
+
+func (d *Downloader[T]) getPricesheetDownloadURL(ctx context.Context) (string, error) {
+	cred, err := azidentity.NewClientSecretCredential(d.TenantID, d.ClientID, d.ClientSecret, nil)
+	if err != nil {
+		return "", fmt.Errorf("creating credential: %w", err)
+	}
+	client, err := NewClient(d.BillingAccount, cred, nil)
+	if err != nil {
+		return "", fmt.Errorf("creating pricesheet client: %w", err)
+	}
+	poller, err := client.BeginDownloadByBillingPeriod(ctx, currentBillingPeriod())
+	if err != nil {
+		return "", fmt.Errorf("beginning pricesheet download: %w", err)
+	}
+	resp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
+		Frequency: 30 * time.Second,
+	})
+	if err != nil {
+		return "", fmt.Errorf("polling for pricesheet: %w", err)
+	}
+	return resp.Properties.DownloadURL, nil
+}
+
+func (d Downloader[T]) saveData(ctx context.Context, url, tempName string) (io.ReadCloser, error) {
+	// Download file from URL in response.
+	out, err := os.CreateTemp("", tempName)
+	if err != nil {
+		return nil, fmt.Errorf("creating %s temp file: %w", tempName, err)
+	}
+
+	resp, err := http.Get(url)
+	if err != nil {
+		return nil, fmt.Errorf("downloading: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected HTTP status %d", resp.StatusCode)
+	}
+
+	if _, err := io.Copy(out, resp.Body); err != nil {
+		return nil, fmt.Errorf("reading response: %w", err)
+	}
+
+	_, err = out.Seek(0, io.SeekStart)
+	if err != nil {
+		return nil, fmt.Errorf("seeking to start of file: %w", err)
+	}
+
+	return &removeOnClose{File: out}, nil
+}
+
+type removeOnClose struct {
+	*os.File
+}
+
+func (r *removeOnClose) Close() error {
+	err := r.File.Close()
+	if err != nil {
+		return err
+	}
+	return os.Remove(r.Name())
+}
+
+func (d *Downloader[T]) readPricesheet(ctx context.Context, data io.Reader) (map[string]*T, error) {
+	// Avoid double-buffering.
+	buf, ok := (data).(*bufio.Reader)
+	if !ok {
+		buf = bufio.NewReader(data)
+	}
+
+	// The CSV file starts with two lines before the header without
+	// commas (so different numbers of fields as far as the CSV parser
+	// is concerned). Skip them before making the CSV reader so we
+	// still get the benefit of the row length checks after the
+	// header.
+	for i := 0; i < 2; i++ {
+		_, err := buf.ReadBytes('\n')
+		if err != nil {
+			return nil, fmt.Errorf("skipping preamble line %d: %w", i, err)
+		}
+	}
+	reader := csv.NewReader(buf)
+	reader.ReuseRecord = true
+
+	header, err := reader.Read()
+	if err != nil {
+		return nil, fmt.Errorf("reading header: %w", err)
+	}
+	if err := checkPricesheetHeader(header); err != nil {
+		return nil, err
+	}
+
+	units := make(map[string]bool)
+
+	results := make(map[string]*T)
+	lines := 2
+	for {
+		row, err := reader.Read()
+		if err == io.EOF {
+			break
+		}
+		lines++
+		if err != nil {
+			return nil, fmt.Errorf("reading line %d: %w", lines, err)
+		}
+
+		// Skip savings plan - we should be reporting based on the
+		// consumption price because we don't know whether the user is
+		// using a savings plan or over their threshold.
+		if row[pricesheetPriceType] == "Savings Plan" || row[pricesheetOfferID] != d.OfferID {
+			continue
+		}
+
+		// TODO: Creating a meter info for each record will cause a
+		// lot of GC churn - is it worth reusing one meter info instead?
+		meterInfo, err := makeMeterInfo(row)
+		if err != nil {
+			log.Warnf("making meter info (line %d): %v", lines, err)
+			continue
+		}
+
+		pricings, err := d.ConvertMeterInfo(meterInfo)
+		if err != nil {
+			log.Warnf("converting meter to pricings (line %d): %v", lines, err)
+			continue
+		}
+
+		if pricings != nil {
+			units[*meterInfo.Unit] = true
+		}
+
+		for key, pricing := range pricings {
+			results[key] = pricing
+		}
+	}
+
+	if len(results) == 0 {
+		return nil, fmt.Errorf("no matching pricing from price sheet")
+	}
+
+	// Keep track of units seen so we can detect if there are any that
+	// need handling.
+	allUnits := make([]string, 0, len(units))
+	for unit := range units {
+		allUnits = append(allUnits, unit)
+	}
+	sort.Strings(allUnits)
+	log.Infof("all units in pricesheet: %s", strings.Join(allUnits, ", "))
+
+	return results, nil
+}
+
+func checkPricesheetHeader(header []string) error {
+	if len(header) < len(pricesheetCols) {
+		return fmt.Errorf("too few header columns: got %d, expected %d", len(header), len(pricesheetCols))
+	}
+	for col, name := range pricesheetCols {
+		if !strings.EqualFold(header[col], name) {
+			return fmt.Errorf("unexpected header at col %d %q, expected %q", col, header[col], name)
+		}
+	}
+	return nil
+}
+
+func makeMeterInfo(row []string) (commerce.MeterInfo, error) {
+	price, err := strconv.ParseFloat(row[pricesheetUnitPrice], 64)
+	if err != nil {
+		return commerce.MeterInfo{}, fmt.Errorf("parsing unit price: %w", err)
+	}
+	newPrice, unit := normalisePrice(price, row[pricesheetUnit])
+	return commerce.MeterInfo{
+		MeterName:        ptr(row[pricesheetMeterName]),
+		MeterCategory:    ptr(row[pricesheetMeterCategory]),
+		MeterSubCategory: ptr(row[pricesheetMeterSubCategory]),
+		Unit:             &unit,
+		MeterRegion:      ptr(row[pricesheetMeterRegion]),
+		MeterRates:       map[string]*float64{"0": &newPrice},
+	}, nil
+}
+
+var pricesheetCols = []string{
+	"Meter ID",
+	"Meter name",
+	"Meter category",
+	"Meter sub-category",
+	"Meter region",
+	"Unit",
+	"Unit of measure",
+	"Part number",
+	"Unit price",
+	"Currency code",
+	"Included quantity",
+	"Offer Id",
+	"Term",
+	"Price type",
+}
+
+const (
+	pricesheetMeterID          = 0
+	pricesheetMeterName        = 1
+	pricesheetMeterCategory    = 2
+	pricesheetMeterSubCategory = 3
+	pricesheetMeterRegion      = 4
+	pricesheetUnit             = 5
+	pricesheetUnitPrice        = 8
+	pricesheetCurrencyCode     = 9
+	pricesheetOfferID          = 11
+	pricesheetPriceType        = 13
+)
+
+func currentBillingPeriod() string {
+	return time.Now().Format("200601")
+}
+
+func ptr[T any](v T) *T {
+	return &v
+}
+
+// conversions lists all the units seen from the price sheet for
+// prices we're interested in with factors to the corresponding units
+// in the rate card.
+var conversions = map[string]struct {
+	divisor float64
+	unit    string
+}{
+	"1 /Month":       {divisor: 1, unit: "1 /Month"},
+	"1 Hour":         {divisor: 1, unit: "1 Hour"},
+	"1 PiB/Hour":     {divisor: 1_000_000, unit: "1 GiB/Hour"},
+	"10 /Month":      {divisor: 10, unit: "1 /Month"},
+	"10 Hours":       {divisor: 10, unit: "1 Hour"},
+	"100 /Month":     {divisor: 100, unit: "1 /Month"},
+	"100 GB/Month":   {divisor: 100, unit: "1 GB/Month"},
+	"100 Hours":      {divisor: 100, unit: "1 Hour"},
+	"100 TiB/Hour":   {divisor: 100_000, unit: "1 GiB/Hour"},
+	"1000 Hours":     {divisor: 1000, unit: "1 Hour"},
+	"10000 Hours":    {divisor: 10_000, unit: "1 Hour"},
+	"100000 /Hour":   {divisor: 100_000, unit: "1 /Hour"},
+	"1000000 /Hour":  {divisor: 1_000_000, unit: "1 /Hour"},
+	"10000000 /Hour": {divisor: 10_000_000, unit: "1 /Hour"},
+}
+
+func normalisePrice(price float64, unit string) (float64, string) {
+	if conv, ok := conversions[unit]; ok {
+		return price / conv.divisor, conv.unit
+	}
+
+	return price, unit
+}

+ 101 - 0
pkg/cloud/azurepricesheet/downloader_test.go

@@ -0,0 +1,101 @@
+package azurepricesheet
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/Azure/azure-sdk-for-go/profiles/2020-09-01/commerce/mgmt/commerce"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDownloader(t *testing.T) {
+	d := Downloader[fakePricing]{
+		TenantID:         "test-tenant-id",
+		ClientID:         "test-client-id",
+		ClientSecret:     "test-client-secret",
+		BillingAccount:   "test-billing-account",
+		OfferID:          "my-offer-id",
+		ConvertMeterInfo: convertMeter,
+	}
+
+	t.Run("read prices", func(t *testing.T) {
+		results, err := d.readPricesheet(context.Background(), strings.NewReader(pricesheetData))
+		require.NoError(t, err)
+
+		// Units and prices are normalised.
+		// Info for saving plans and other offers is skipped.
+		expected := map[string]*fakePricing{
+			"DC96as_v4": {price: "10.505", unit: "1 Hour"},
+			"DC2as_v4":  {price: "0.219", unit: "1 Hour"},
+			"VM1":       {price: "1.0", unit: "1 Hour"},
+			"VM2":       {price: "2.0", unit: "1 Hour"},
+		}
+		require.Equal(t, expected, results)
+	})
+
+	t.Run("bad header", func(t *testing.T) {
+		data := "\n\nMeter ID,Meter name,Meter category,Something else,,,,,,,,,,,,,,\n"
+		_, err := d.readPricesheet(context.Background(), strings.NewReader(data))
+		require.ErrorContains(t, err, `unexpected header at col 3 "Something else", expected "Meter sub-category"`)
+	})
+
+	t.Run("short header", func(t *testing.T) {
+		data := "\n\nMeter ID, Meter name, Meter category, Meter sub-category\n"
+		_, err := d.readPricesheet(context.Background(), strings.NewReader(data))
+		require.ErrorContains(t, err, "too few header columns: got 4, expected 14")
+	})
+
+	t.Run("no matching prices", func(t *testing.T) {
+		d := Downloader[fakePricing]{
+			TenantID:       "test-tenant-id",
+			ClientID:       "test-client-id",
+			ClientSecret:   "test-client-secret",
+			BillingAccount: "test-billing-account",
+			OfferID:        "my-offer-id",
+			ConvertMeterInfo: func(commerce.MeterInfo) (map[string]*fakePricing, error) {
+				return nil, nil
+			},
+		}
+		_, err := d.readPricesheet(context.Background(), strings.NewReader(pricesheetData))
+		require.ErrorContains(t, err, "no matching pricing from price sheet")
+	})
+}
+
+func convertMeter(info commerce.MeterInfo) (map[string]*fakePricing, error) {
+	switch *info.MeterName {
+	case "skip-this":
+		return nil, nil
+	case "multiple-prices":
+		return map[string]*fakePricing{
+			"VM1": {price: "1.0", unit: "1 Hour"},
+			"VM2": {price: "2.0", unit: "1 Hour"},
+		}, nil
+	case "error":
+		return nil, fmt.Errorf("there was an error handling this row!")
+	default:
+		return map[string]*fakePricing{
+			*info.MeterName: {price: fmt.Sprintf("%0.3f", *info.MeterRates["0"]), unit: *info.Unit},
+		}, nil
+	}
+}
+
+type fakePricing struct {
+	price string
+	unit  string
+}
+
+const pricesheetData = `Price Sheet Report for billing period - 202304
+
+Meter ID,Meter name,Meter category,Meter sub-category,Meter region,Unit,Unit of measure,Part number,Unit price,Currency code,Included quantity,Offer Id,Term,Price type
+d4236f8f-3ba6-5a9a-8c6b-14556538c44c,DC96as_v4,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70822,105.050000000000000,USD,0.00,my-offer-id,,Consumption
+d4236f8f-3ba6-5a9a-8c6b-14556538c44c,DC96as_v4,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70831,60.890000000000000,USD,0.00,other-offer-id,,Consumption
+e47a2c4c-4dc4-55d5-a8d7-ec5b1dcc9c08,DC2as_v4,Virtual Machines,DCasv4 Series,US East,100 Hours,100 Hours,AAF-70890,21.900000000000000,USD,0.000,my-offer-id,,Consumption
+e47a2c4c-4dc4-55d5-a8d7-ec5b1dcc9c08,DC2as_v4,Virtual Machines,DCasv4 Series,US East,100 Hours,100 Hours,AAF-70886,12.700000000000000,USD,0.000,other-offer-id,,Consumption
+cb8d72c0-2b02-5b41-9ac9-2809c04f17ff,DC16as_v4,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70911,17.510000000000000,USD,0.00,my-offer-id,,Savings Plan
+cb8d72c0-2b02-5b41-9ac9-2809c04f17ff,DC16as_v4,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70910,10.150000000000000,USD,0.00,other-offer-id,,Consumption
+d4236f8f-3ba6-5a9a-8c6b-14556538c44c,skip-this,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70822,105.050000000000000,USD,0.00,my-offer-id,,Consumption
+d4236f8f-3ba6-5a9a-8c6b-14556538c44c,multiple-prices,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70822,105.050000000000000,USD,0.00,my-offer-id,,Consumption
+d4236f8f-3ba6-5a9a-8c6b-14556538c44c,error,Virtual Machines,DCasv4 Series,US East,10 Hours,10 Hours,AAF-70822,105.050000000000000,USD,0.00,my-offer-id,,Consumption
+`

+ 199 - 104
pkg/cloud/azureprovider.go

@@ -13,8 +13,6 @@ import (
 	"sync"
 	"time"
 
-	"github.com/opencost/opencost/pkg/kubecost"
-
 	"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2021-11-01/compute"
 	"github.com/Azure/azure-sdk-for-go/services/preview/commerce/mgmt/2015-06-01-preview/commerce"
 	"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2016-06-01/subscriptions"
@@ -22,13 +20,17 @@ import (
 	"github.com/Azure/go-autorest/autorest"
 	"github.com/Azure/go-autorest/autorest/azure"
 	"github.com/Azure/go-autorest/autorest/azure/auth"
+
+	pricesheet "github.com/opencost/opencost/pkg/cloud/azurepricesheet"
 	"github.com/opencost/opencost/pkg/clustercache"
 	"github.com/opencost/opencost/pkg/env"
+	"github.com/opencost/opencost/pkg/kubecost"
 	"github.com/opencost/opencost/pkg/log"
 	"github.com/opencost/opencost/pkg/util"
 	"github.com/opencost/opencost/pkg/util/fileutil"
 	"github.com/opencost/opencost/pkg/util/json"
 	"github.com/opencost/opencost/pkg/util/timeutil"
+
 	v1 "k8s.io/api/core/v1"
 )
 
@@ -205,7 +207,7 @@ func getRegions(service string, subscriptionsClient subscriptions.Client, provid
 						if loc, ok := allLocations[displName]; ok {
 							supLocations[loc] = displName
 						} else {
-							log.Warnf("unsupported cloud region %s", loc)
+							log.Warnf("unsupported cloud region %q", loc)
 						}
 					}
 					break
@@ -223,7 +225,7 @@ func getRegions(service string, subscriptionsClient subscriptions.Client, provid
 						if loc, ok := allLocations[displName]; ok {
 							supLocations[loc] = displName
 						} else {
-							log.Warnf("unsupported cloud region %s", loc)
+							log.Warnf("unsupported cloud region %q", loc)
 						}
 					}
 					break
@@ -328,7 +330,7 @@ func toRegionID(meterRegion string, regions map[string]string) (string, error) {
 			return regionID, nil
 		}
 	}
-	return "", fmt.Errorf("Couldn't find region")
+	return "", fmt.Errorf("Couldn't find region %q", meterRegion)
 }
 
 // azure has very inconsistent naming standards between display names from the rate card api and display names from the regions api
@@ -395,7 +397,9 @@ type Azure struct {
 	Clientset                      clustercache.ClusterCache
 	Config                         *ProviderConfig
 	serviceAccountChecks           *ServiceAccountChecks
-	RateCardPricingError           error
+	pricingSource                  string
+	rateCardPricingError           error
+	priceSheetPricingError         error
 	clusterAccountID               string
 	clusterRegion                  string
 	loadedAzureSecret              bool
@@ -779,10 +783,19 @@ func (az *Azure) DownloadPricingData() error {
 
 	config, err := az.GetConfig()
 	if err != nil {
-		az.RateCardPricingError = err
+		az.rateCardPricingError = err
 		return err
 	}
 
+	envBillingAccount := env.GetAzureBillingAccount()
+	if envBillingAccount != "" {
+		config.AzureBillingAccount = envBillingAccount
+	}
+	envOfferID := env.GetAzureOfferID()
+	if envOfferID != "" {
+		config.AzureOfferDurableID = envOfferID
+	}
+
 	// Load the service provider keys
 	subscriptionID, clientID, clientSecret, tenantID := az.getAzureRateCardAuth(false, config)
 	config.AzureSubscriptionID = subscriptionID
@@ -798,7 +811,7 @@ func (az *Azure) DownloadPricingData() error {
 		credentialsConfig := NewClientCredentialsConfig(config.AzureClientID, config.AzureClientSecret, config.AzureTenantID, azureEnv)
 		a, err := credentialsConfig.Authorizer()
 		if err != nil {
-			az.RateCardPricingError = err
+			az.rateCardPricingError = err
 			return err
 		}
 		authorizer = a
@@ -810,7 +823,7 @@ func (az *Azure) DownloadPricingData() error {
 		if err != nil {
 			a, err := auth.NewAuthorizerFromFile(azureEnv.ResourceManagerEndpoint)
 			if err != nil {
-				az.RateCardPricingError = err
+				az.rateCardPricingError = err
 				return err
 			}
 			authorizer = a
@@ -832,14 +845,14 @@ func (az *Azure) DownloadPricingData() error {
 	result, err := rcClient.Get(context.TODO(), rateCardFilter)
 	if err != nil {
 		log.Warnf("Error in pricing download query from API")
-		az.RateCardPricingError = err
+		az.rateCardPricingError = err
 		return err
 	}
 
 	regions, err := getRegions("compute", sClient, providersClient, config.AzureSubscriptionID)
 	if err != nil {
 		log.Warnf("Error in pricing download regions from API")
-		az.RateCardPricingError = err
+		az.rateCardPricingError = err
 		return err
 	}
 
@@ -847,107 +860,166 @@ func (az *Azure) DownloadPricingData() error {
 	allPrices := make(map[string]*AzurePricing)
 
 	for _, v := range *result.Meters {
-		meterName := *v.MeterName
-		meterRegion := *v.MeterRegion
-		meterCategory := *v.MeterCategory
-		meterSubCategory := *v.MeterSubCategory
-
-		region, err := toRegionID(meterRegion, regions)
+		pricings, err := convertMeterToPricings(v, regions, baseCPUPrice)
 		if err != nil {
+			log.Warnf("converting meter to pricings: %s", err.Error())
 			continue
 		}
+		for key, pricing := range pricings {
+			allPrices[key] = pricing
+		}
+	}
+	addAzureFilePricing(allPrices, regions)
 
-		if !strings.Contains(meterSubCategory, "Windows") {
-
-			if strings.Contains(meterCategory, "Storage") {
-				if strings.Contains(meterSubCategory, "HDD") || strings.Contains(meterSubCategory, "SSD") || strings.Contains(meterSubCategory, "Premium Files") {
-					var storageClass string = ""
-					if strings.Contains(meterName, "P4 ") {
-						storageClass = AzureDiskPremiumSSDStorageClass
-					} else if strings.Contains(meterName, "E4 ") {
-						storageClass = AzureDiskStandardSSDStorageClass
-					} else if strings.Contains(meterName, "S4 ") {
-						storageClass = AzureDiskStandardStorageClass
-					} else if strings.Contains(meterName, "LRS Provisioned") {
-						storageClass = AzureFilePremiumStorageClass
-					}
-
-					if storageClass != "" {
-						var priceInUsd float64
-
-						if len(v.MeterRates) < 1 {
-							log.Warnf("missing rate info %+v", map[string]interface{}{"MeterSubCategory": *v.MeterSubCategory, "region": region})
-							continue
-						}
-						for _, rate := range v.MeterRates {
-							priceInUsd += *rate
-						}
-						// rate is in disk per month, resolve price per hour, then GB per hour
-						pricePerHour := priceInUsd / 730.0 / 32.0
-						priceStr := fmt.Sprintf("%f", pricePerHour)
-
-						key := region + "," + storageClass
-						log.Debugf("Adding PV.Key: %s, Cost: %s", key, priceStr)
-						allPrices[key] = &AzurePricing{
-							PV: &PV{
-								Cost:   priceStr,
-								Region: region,
-							},
-						}
-					}
-				}
+	az.Pricing = allPrices
+	az.pricingSource = rateCardPricingSource
+	az.rateCardPricingError = nil
+
+	// If we've got a billing account set, kick off downloading the custom pricing data.
+	if config.AzureBillingAccount != "" {
+		downloader := pricesheet.Downloader[AzurePricing]{
+			TenantID:       config.AzureTenantID,
+			ClientID:       config.AzureClientID,
+			ClientSecret:   config.AzureClientSecret,
+			BillingAccount: config.AzureBillingAccount,
+			OfferID:        config.AzureOfferDurableID,
+			ConvertMeterInfo: func(meterInfo commerce.MeterInfo) (map[string]*AzurePricing, error) {
+				return convertMeterToPricings(meterInfo, regions, baseCPUPrice)
+			},
+		}
+		// The price sheet can take 5 minutes to generate, so we don't
+		// want to hang onto the lock while we're waiting for it.
+		go func() {
+			ctx := context.Background()
+			allPrices, err := downloader.GetPricing(ctx)
+
+			az.DownloadPricingDataLock.Lock()
+			defer az.DownloadPricingDataLock.Unlock()
+			if err != nil {
+				log.Errorf("Error downloading Azure price sheet: %s", err)
+				az.priceSheetPricingError = err
+				return
 			}
+			addAzureFilePricing(allPrices, regions)
+			az.Pricing = allPrices
+			az.pricingSource = priceSheetPricingSource
+			az.priceSheetPricingError = nil
+		}()
+	}
 
-			if strings.Contains(meterCategory, "Virtual Machines") {
-
-				usageType := ""
-				if !strings.Contains(meterName, "Low Priority") {
-					usageType = "ondemand"
-				} else {
-					usageType = "preemptible"
-				}
+	return nil
+}
 
-				var instanceTypes []string
-				name := strings.TrimSuffix(meterName, " Low Priority")
-				instanceType := strings.Split(name, "/")
-				for _, it := range instanceType {
-					if strings.Contains(meterSubCategory, "Promo") {
-						it = it + " Promo"
-					}
-					instanceTypes = append(instanceTypes, strings.Replace(it, " ", "_", 1))
-				}
+func convertMeterToPricings(info commerce.MeterInfo, regions map[string]string, baseCPUPrice string) (map[string]*AzurePricing, error) {
+	meterName := *info.MeterName
+	meterRegion := *info.MeterRegion
+	meterCategory := *info.MeterCategory
+	meterSubCategory := *info.MeterSubCategory
 
-				instanceTypes = transformMachineType(meterSubCategory, instanceTypes)
-				if strings.Contains(name, "Expired") {
-					instanceTypes = []string{}
-				}
+	region, err := toRegionID(meterRegion, regions)
+	if err != nil {
+		// Skip this meter if we don't recognize the region.
+		return nil, nil
+	}
+
+	if strings.Contains(meterSubCategory, "Windows") {
+		// This meter doesn't correspond to any pricings.
+		return nil, nil
+	}
+
+	if strings.Contains(meterCategory, "Storage") {
+		if strings.Contains(meterSubCategory, "HDD") || strings.Contains(meterSubCategory, "SSD") || strings.Contains(meterSubCategory, "Premium Files") {
+			var storageClass string = ""
+			if strings.Contains(meterName, "P4 ") {
+				storageClass = AzureDiskPremiumSSDStorageClass
+			} else if strings.Contains(meterName, "E4 ") {
+				storageClass = AzureDiskStandardSSDStorageClass
+			} else if strings.Contains(meterName, "S4 ") {
+				storageClass = AzureDiskStandardStorageClass
+			} else if strings.Contains(meterName, "LRS Provisioned") {
+				storageClass = AzureFilePremiumStorageClass
+			}
 
+			if storageClass != "" {
 				var priceInUsd float64
 
-				if len(v.MeterRates) < 1 {
-					log.Warnf("missing rate info %+v", map[string]interface{}{"MeterSubCategory": *v.MeterSubCategory, "region": region})
-					continue
+				if len(info.MeterRates) < 1 {
+					return nil, fmt.Errorf("missing rate info %+v", map[string]interface{}{"MeterSubCategory": *info.MeterSubCategory, "region": region})
 				}
-				for _, rate := range v.MeterRates {
+				for _, rate := range info.MeterRates {
 					priceInUsd += *rate
 				}
-				priceStr := fmt.Sprintf("%f", priceInUsd)
-				for _, instanceType := range instanceTypes {
-
-					key := fmt.Sprintf("%s,%s,%s", region, instanceType, usageType)
-
-					allPrices[key] = &AzurePricing{
-						Node: &Node{
-							Cost:         priceStr,
-							BaseCPUPrice: baseCPUPrice,
-							UsageType:    usageType,
+				// rate is in disk per month, resolve price per hour, then GB per hour
+				pricePerHour := priceInUsd / 730.0 / 32.0
+				priceStr := fmt.Sprintf("%f", pricePerHour)
+
+				key := region + "," + storageClass
+				log.Debugf("Adding PV.Key: %s, Cost: %s", key, priceStr)
+				return map[string]*AzurePricing{
+					key: {
+						PV: &PV{
+							Cost:   priceStr,
+							Region: region,
 						},
-					}
-				}
+					},
+				}, nil
 			}
 		}
 	}
 
+	if !strings.Contains(meterCategory, "Virtual Machines") {
+		return nil, nil
+	}
+
+	usageType := ""
+	if !strings.Contains(meterName, "Low Priority") {
+		usageType = "ondemand"
+	} else {
+		usageType = "preemptible"
+	}
+
+	var instanceTypes []string
+	name := strings.TrimSuffix(meterName, " Low Priority")
+	instanceType := strings.Split(name, "/")
+	for _, it := range instanceType {
+		if strings.Contains(meterSubCategory, "Promo") {
+			it = it + " Promo"
+		}
+		instanceTypes = append(instanceTypes, strings.Replace(it, " ", "_", 1))
+	}
+
+	instanceTypes = transformMachineType(meterSubCategory, instanceTypes)
+	if strings.Contains(name, "Expired") {
+		instanceTypes = []string{}
+	}
+
+	var priceInUsd float64
+
+	if len(info.MeterRates) < 1 {
+		return nil, fmt.Errorf("missing rate info %+v", map[string]interface{}{"MeterSubCategory": *info.MeterSubCategory, "region": region})
+	}
+	for _, rate := range info.MeterRates {
+		priceInUsd += *rate
+	}
+	priceStr := fmt.Sprintf("%f", priceInUsd)
+	results := make(map[string]*AzurePricing)
+	for _, instanceType := range instanceTypes {
+
+		key := fmt.Sprintf("%s,%s,%s", region, instanceType, usageType)
+		pricing := &AzurePricing{
+			Node: &Node{
+				Cost:         priceStr,
+				BaseCPUPrice: baseCPUPrice,
+				UsageType:    usageType,
+			},
+		}
+		results[key] = pricing
+	}
+	return results, nil
+
+}
+
+func addAzureFilePricing(prices map[string]*AzurePricing, regions map[string]string) {
 	// There is no easy way of supporting Standard Azure-File, because it's billed per used GB
 	// this will set the price to "0" as a workaround to not spam with `Persistent Volume pricing not found for` error
 	// check https://github.com/opencost/opencost/issues/159 for more information (same problem on AWS)
@@ -955,17 +1027,13 @@ func (az *Azure) DownloadPricingData() error {
 	for region := range regions {
 		key := region + "," + AzureFileStandardStorageClass
 		log.Debugf("Adding PV.Key: %s, Cost: %s", key, zeroPrice)
-		allPrices[key] = &AzurePricing{
+		prices[key] = &AzurePricing{
 			PV: &PV{
 				Cost:   zeroPrice,
 				Region: region,
 			},
 		}
 	}
-
-	az.Pricing = allPrices
-	az.RateCardPricingError = nil
-	return nil
 }
 
 // determineCloudByRegion uses region name to pick the correct Cloud Environment for the azure provider to use
@@ -1205,7 +1273,7 @@ func (az *Azure) getDisks() ([]*compute.Disk, error) {
 		credentialsConfig := NewClientCredentialsConfig(config.AzureClientID, config.AzureClientSecret, config.AzureTenantID, azureEnv)
 		a, err := credentialsConfig.Authorizer()
 		if err != nil {
-			az.RateCardPricingError = err
+			az.rateCardPricingError = err
 			return nil, err
 		}
 		authorizer = a
@@ -1217,7 +1285,7 @@ func (az *Azure) getDisks() ([]*compute.Disk, error) {
 		if err != nil {
 			a, err := auth.NewAuthorizerFromFile(azureEnv.ResourceManagerEndpoint)
 			if err != nil {
-				az.RateCardPricingError = err
+				az.rateCardPricingError = err
 				return nil, err
 			}
 			authorizer = a
@@ -1472,18 +1540,23 @@ func (az *Azure) ServiceAccountStatus() *ServiceAccountStatus {
 	return az.serviceAccountChecks.getStatus()
 }
 
-const rateCardPricingSource = "Rate Card API"
+const (
+	rateCardPricingSource   = "Rate Card API"
+	priceSheetPricingSource = "Price Sheet API"
+)
 
 // PricingSourceStatus returns the status of the rate card api
 func (az *Azure) PricingSourceStatus() map[string]*PricingSource {
+	az.DownloadPricingDataLock.Lock()
+	defer az.DownloadPricingDataLock.Unlock()
 	sources := make(map[string]*PricingSource)
 	errMsg := ""
-	if az.RateCardPricingError != nil {
-		errMsg = az.RateCardPricingError.Error()
+	if az.rateCardPricingError != nil {
+		errMsg = az.rateCardPricingError.Error()
 	}
 	rcps := &PricingSource{
 		Name:    rateCardPricingSource,
-		Enabled: true,
+		Enabled: az.pricingSource == rateCardPricingSource,
 		Error:   errMsg,
 	}
 	if rcps.Error != "" {
@@ -1494,7 +1567,29 @@ func (az *Azure) PricingSourceStatus() map[string]*PricingSource {
 	} else {
 		rcps.Available = true
 	}
+
+	errMsg = ""
+	if az.priceSheetPricingError != nil {
+		errMsg = az.priceSheetPricingError.Error()
+	}
+	psps := &PricingSource{
+		Name:    priceSheetPricingSource,
+		Enabled: az.pricingSource == priceSheetPricingSource,
+		Error:   errMsg,
+	}
+	if psps.Error != "" {
+		psps.Available = false
+	} else if len(az.Pricing) == 0 {
+		psps.Error = "No Pricing Data Available"
+		psps.Available = false
+	} else if env.GetAzureBillingAccount() == "" {
+		psps.Error = "No Azure Billing Account ID"
+		psps.Available = false
+	} else {
+		psps.Available = true
+	}
 	sources[rateCardPricingSource] = rcps
+	sources[priceSheetPricingSource] = psps
 	return sources
 }
 

+ 59 - 0
pkg/cloud/azureprovider_test.go

@@ -2,6 +2,9 @@ package cloud
 
 import (
 	"testing"
+
+	"github.com/Azure/azure-sdk-for-go/services/preview/commerce/mgmt/2015-06-01-preview/commerce"
+	"github.com/stretchr/testify/require"
 )
 
 func TestParseAzureSubscriptionID(t *testing.T) {
@@ -34,3 +37,59 @@ func TestParseAzureSubscriptionID(t *testing.T) {
 		}
 	}
 }
+
+func TestConvertMeterToPricings(t *testing.T) {
+	regions := map[string]string{
+		"useast":             "US East",
+		"japanwest":          "Japan West",
+		"australiasoutheast": "Australia Southeast",
+		"norwaywest":         "Norway West",
+	}
+	baseCPUPrice := "0.30000"
+
+	meterInfo := func(category, subcategory, name, region string, rate float64) commerce.MeterInfo {
+		return commerce.MeterInfo{
+			MeterCategory:    &category,
+			MeterSubCategory: &subcategory,
+			MeterName:        &name,
+			MeterRegion:      &region,
+			MeterRates:       map[string]*float64{"0": &rate},
+		}
+	}
+
+	t.Run("windows", func(t *testing.T) {
+		info := meterInfo("Virtual Machines", "D2 Series Windows", "D2s v3", "AU Southeast", 0.3)
+		results, err := convertMeterToPricings(info, regions, baseCPUPrice)
+		require.NoError(t, err)
+		require.Nil(t, results)
+	})
+
+	t.Run("storage", func(t *testing.T) {
+		info := meterInfo("Storage", "Some SSD type", "P4 are good", "US East", 2000)
+		results, err := convertMeterToPricings(info, regions, baseCPUPrice)
+		require.NoError(t, err)
+
+		expected := map[string]*AzurePricing{
+			"useast,premium_ssd": {
+				PV: &PV{Cost: "0.085616", Region: "useast"},
+			},
+		}
+		require.Equal(t, expected, results)
+	})
+
+	t.Run("virtual machines", func(t *testing.T) {
+		info := meterInfo("Virtual Machines", "Eav4/Easv4 Series", "E96a v4/E96as v4 Low Priority", "JA West", 10)
+		results, err := convertMeterToPricings(info, regions, baseCPUPrice)
+		require.NoError(t, err)
+
+		expected := map[string]*AzurePricing{
+			"japanwest,Standard_E96a_v4,preemptible": {
+				Node: &Node{Cost: "10.000000", BaseCPUPrice: "0.30000", UsageType: "preemptible"},
+			},
+			"japanwest,Standard_E96as_v4,preemptible": {
+				Node: &Node{Cost: "10.000000", BaseCPUPrice: "0.30000", UsageType: "preemptible"},
+			},
+		}
+		require.Equal(t, expected, results)
+	})
+}

+ 1 - 0
pkg/cloud/provider.go

@@ -211,6 +211,7 @@ type CustomPricing struct {
 	AzureClientSecret            string `json:"azureClientSecret"`
 	AzureTenantID                string `json:"azureTenantID"`
 	AzureBillingRegion           string `json:"azureBillingRegion"`
+	AzureBillingAccount          string `json:"azureBillingAccount"`
 	AzureOfferDurableID          string `json:"azureOfferDurableID"`
 	AzureStorageSubscriptionID   string `json:"azureStorageSubscriptionID"`
 	AzureStorageAccount          string `json:"azureStorageAccount"`

+ 18 - 0
pkg/env/costmodelenv.go

@@ -18,6 +18,9 @@ const (
 	AlibabaAccessKeyIDEnvVar     = "ALIBABA_ACCESS_KEY_ID"
 	AlibabaAccessKeySecretEnvVar = "ALIBABA_SECRET_ACCESS_KEY"
 
+	AzureOfferIDEnvVar        = "AZURE_OFFER_ID"
+	AzureBillingAccountEnvVar = "AZURE_BILLING_ACCOUNT"
+
 	KubecostNamespaceEnvVar        = "KUBECOST_NAMESPACE"
 	PodNameEnvVar                  = "POD_NAME"
 	ClusterIDEnvVar                = "CLUSTER_ID"
@@ -228,6 +231,21 @@ func GetAlibabaAccessKeySecret() string {
 	return Get(AlibabaAccessKeySecretEnvVar, "")
 }
 
+// GetAzureOfferID returns the environment variable value for AzureOfferIDEnvVar which represents
+// the Azure offer ID for determining prices.
+func GetAzureOfferID() string {
+	return Get(AzureOfferIDEnvVar, "")
+}
+
+// GetAzureBillingAccount returns the environment variable value for
+// AzureBillingAccountEnvVar which represents the Azure billing
+// account for determining prices. If this is specified
+// customer-specific prices will be downloaded from the consumption
+// price sheet API.
+func GetAzureBillingAccount() string {
+	return Get(AzureBillingAccountEnvVar, "")
+}
+
 // GetKubecostNamespace returns the environment variable value for KubecostNamespaceEnvVar which
 // represents the namespace the cost model exists in.
 func GetKubecostNamespace() string {