Jelajahi Sumber

initial changes for supporting DefaultAzureCredentials

Signed-off-by: David Soff <david@soff.nl>
David Soff 2 tahun lalu
induk
melakukan
5adde3b68a

+ 66 - 4
pkg/cloud/azure/authorizer.go

@@ -4,15 +4,17 @@ import (
 	"encoding/json"
 	"fmt"
 
-	"github.com/Azure/azure-storage-blob-go/azblob"
+	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
 	"github.com/opencost/opencost/pkg/cloud"
 )
 
 const AccessKeyAuthorizerType = "AzureAccessKey"
+const DefaultAzureCredentialHolderAuthorizerType = "DefaultAzureCredentialHolder"
 
 type Authorizer interface {
 	cloud.Authorizer
-	GetBlobCredentials() (azblob.Credential, error)
+	GetBlobClient(urlTemplate string) (*azblob.Client, error)
 }
 
 // SelectAuthorizerByType is an implementation of AuthorizerSelectorFn and acts as a register for Authorizer types
@@ -20,6 +22,8 @@ func SelectAuthorizerByType(typeStr string) (Authorizer, error) {
 	switch typeStr {
 	case AccessKeyAuthorizerType:
 		return &AccessKey{}, nil
+	case DefaultAzureCredentialHolderAuthorizerType:
+		return &DefaultAzureCredentialHolder{}, nil
 	default:
 		return nil, fmt.Errorf("azure: provider authorizer type '%s' is not valid", typeStr)
 	}
@@ -74,7 +78,65 @@ func (ak *AccessKey) Sanitize() cloud.Config {
 	}
 }
 
-func (ak *AccessKey) GetBlobCredentials() (azblob.Credential, error) {
+func (ak *AccessKey) GetBlobClient(urlTemplate string) (*azblob.Client, error) {
 	// Create a default request pipeline using your storage account name and account key.
-	return azblob.NewSharedKeyCredential(ak.Account, ak.AccessKey)
+	serviceURL := fmt.Sprintf(urlTemplate, ak.Account, "")
+
+	credential, err := azblob.NewSharedKeyCredential(ak.Account, ak.AccessKey)
+	if err != nil {
+		return nil, err
+	}
+	client, err := azblob.NewClientWithSharedKeyCredential(serviceURL, credential, nil)
+	return client, err
+}
+
+type DefaultAzureCredentialHolder struct {
+	Account string `json:"account"`
+}
+
+func (dac *DefaultAzureCredentialHolder) MarshalJSON() ([]byte, error) {
+	fmap := make(map[string]any, 2)
+	fmap[cloud.AuthorizerTypeProperty] = DefaultAzureCredentialHolderAuthorizerType
+	fmap["account"] = dac.Account
+	return json.Marshal(fmap)
+}
+
+func (dac *DefaultAzureCredentialHolder) Validate() error {
+	if dac.Account == "" {
+		return fmt.Errorf("AccessKey: missing account")
+	}
+	return nil
+}
+
+func (dac *DefaultAzureCredentialHolder) Equals(config cloud.Config) bool {
+	if config == nil {
+		return false
+	}
+	thatConfig, ok := config.(*DefaultAzureCredentialHolder)
+	if !ok {
+		return false
+	}
+
+	if dac.Account != thatConfig.Account {
+		return false
+	}
+
+	return true
+}
+
+func (dac *DefaultAzureCredentialHolder) Sanitize() cloud.Config {
+	return &DefaultAzureCredentialHolder{}
+}
+
+func (dac *DefaultAzureCredentialHolder) GetBlobClient(urlTemplate string) (*azblob.Client, error) {
+
+	serviceURL := fmt.Sprintf(urlTemplate, dac.Account, "")
+	// Create a default request pipeline using your storage account name and account key.
+	cred, err := azidentity.NewDefaultAzureCredential(nil)
+	if err != nil {
+		return nil, err
+	}
+
+	client, err := azblob.NewClient(serviceURL, cred, nil)
+	return client, err
 }

+ 19 - 18
pkg/cloud/azure/storagebillingparser.go

@@ -9,7 +9,8 @@ import (
 	"strings"
 	"time"
 
-	"github.com/Azure/azure-storage-blob-go/azblob"
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
 	"github.com/opencost/opencost/pkg/cloud"
 	"github.com/opencost/opencost/pkg/log"
 )
@@ -36,13 +37,13 @@ func (asbp *AzureStorageBillingParser) ParseBillingData(start, end time.Time, re
 		return err
 	}
 
-	containerURL, err := asbp.getContainer()
+	client, err := asbp.Authorizer.GetBlobClient(asbp.StorageConnection.getBlobURLTemplate())
 	if err != nil {
 		asbp.ConnectionStatus = cloud.FailedConnection
 		return err
 	}
 	ctx := context.Background()
-	blobNames, err := asbp.getMostRecentBlobs(start, end, containerURL, ctx)
+	blobNames, err := asbp.getMostRecentBlobs(start, end, client, ctx)
 	if err != nil {
 		asbp.ConnectionStatus = cloud.FailedConnection
 		return err
@@ -54,7 +55,7 @@ func (asbp *AzureStorageBillingParser) ParseBillingData(start, end time.Time, re
 	}
 
 	for _, blobName := range blobNames {
-		blobBytes, err2 := asbp.DownloadBlob(blobName, containerURL, ctx)
+		blobBytes, err2 := asbp.DownloadBlob(blobName, client, ctx)
 		if err2 != nil {
 			asbp.ConnectionStatus = cloud.FailedConnection
 			return err2
@@ -101,7 +102,7 @@ func (asbp *AzureStorageBillingParser) parseCSV(start, end time.Time, reader *cs
 	return nil
 }
 
-func (asbp *AzureStorageBillingParser) getMostRecentBlobs(start, end time.Time, containerURL *azblob.ContainerURL, ctx context.Context) ([]string, error) {
+func (asbp *AzureStorageBillingParser) getMostRecentBlobs(start, end time.Time, client *azblob.Client, ctx context.Context) ([]string, error) {
 	log.Infof("Azure Storage: retrieving most recent reports from: %v - %v", start, end)
 
 	// Get list of month substrings for months contained in the start to end range
@@ -109,24 +110,24 @@ func (asbp *AzureStorageBillingParser) getMostRecentBlobs(start, end time.Time,
 	if err != nil {
 		return nil, err
 	}
-	mostResentBlobs := make(map[string]azblob.BlobItemInternal)
-	for marker := (azblob.Marker{}); marker.NotDone(); {
-		// Get a result segment starting with the blob indicated by the current Marker.
-		listBlob, err := containerURL.ListBlobsFlatSegment(ctx, marker, azblob.ListBlobsSegmentOptions{})
+	mostResentBlobs := make(map[string]container.BlobItem)
+
+	pager := client.NewListBlobsFlatPager(asbp.Container, &azblob.ListBlobsFlatOptions{
+		Include: container.ListBlobsInclude{Deleted: false, Versions: false},
+	})
+
+	for pager.More() {
+		resp, err := pager.NextPage(ctx)
 		if err != nil {
 			return nil, err
 		}
 
-		// ListBlobs returns the start of the next segment; you MUST use this to get
-		// the next segment (after processing the current result segment).
-		marker = listBlob.NextMarker
-
 		// Using the list of months strings find the most resent blob for each month in the range
-		for _, blobInfo := range listBlob.Segment.BlobItems {
+		for _, blobInfo := range resp.Segment.BlobItems {
 			for _, month := range monthStrs {
-				if strings.Contains(blobInfo.Name, month) {
+				if strings.Contains(*blobInfo.Name, month) {
 					// If Container Path configuration exists, check if it is in the blobs name
-					if asbp.Path != "" && !strings.Contains(blobInfo.Name, asbp.Path) {
+					if asbp.Path != "" && !strings.Contains(*blobInfo.Name, asbp.Path) {
 						continue
 					}
 
@@ -135,7 +136,7 @@ func (asbp *AzureStorageBillingParser) getMostRecentBlobs(start, end time.Time,
 							continue
 						}
 					}
-					mostResentBlobs[month] = blobInfo
+					mostResentBlobs[month] = *blobInfo
 				}
 			}
 		}
@@ -145,7 +146,7 @@ func (asbp *AzureStorageBillingParser) getMostRecentBlobs(start, end time.Time,
 	var blobNames []string
 	for _, month := range monthStrs {
 		if blob, ok := mostResentBlobs[month]; ok {
-			blobNames = append(blobNames, blob.Name)
+			blobNames = append(blobNames, *blob.Name)
 		}
 	}
 

+ 11 - 27
pkg/cloud/azure/storageconnection.go

@@ -3,11 +3,9 @@ package azure
 import (
 	"bytes"
 	"context"
-	"fmt"
-	"net/url"
 	"strings"
 
-	"github.com/Azure/azure-storage-blob-go/azblob"
+	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
 	"github.com/opencost/opencost/pkg/cloud"
 	"github.com/opencost/opencost/pkg/log"
 )
@@ -35,25 +33,6 @@ func (sc *StorageConnection) Equals(config cloud.Config) bool {
 	return sc.StorageConfiguration.Equals(&thatConfig.StorageConfiguration)
 }
 
-func (sc *StorageConnection) getContainer() (*azblob.ContainerURL, error) {
-
-	credential, err := sc.Authorizer.GetBlobCredentials()
-	if err != nil {
-		return nil, err
-	}
-
-	p := azblob.NewPipeline(credential, azblob.PipelineOptions{})
-
-	// From the Azure portal, get your storage account blob service URL endpoint.
-	URL, _ := url.Parse(
-		fmt.Sprintf(sc.getBlobURLTemplate(), sc.Account, sc.Container))
-
-	// Create a ContainerURL object that wraps the container URL and a request
-	// pipeline to make requests.
-	containerURL := azblob.NewContainerURL(*URL, p)
-	return &containerURL, nil
-}
-
 // getBlobURLTemplate returns the correct BlobUrl for whichever Cloud storage account is specified by the AzureCloud configuration
 // defaults to the Public Cloud template
 func (sc *StorageConnection) getBlobURLTemplate() string {
@@ -65,20 +44,25 @@ func (sc *StorageConnection) getBlobURLTemplate() string {
 	return "https://%s.blob.core.windows.net/%s"
 }
 
-func (sc *StorageConnection) DownloadBlob(blobName string, containerURL *azblob.ContainerURL, ctx context.Context) ([]byte, error) {
+func (sc *StorageConnection) DownloadBlob(blobName string, client *azblob.Client, ctx context.Context) ([]byte, error) {
 	log.Infof("Azure Storage: retrieving blob: %v", blobName)
 
-	blobURL := containerURL.NewBlobURL(blobName)
-	downloadResponse, err := blobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false, azblob.ClientProvidedKeyOptions{})
+	downloadResponse, err := client.DownloadStream(ctx, sc.Container, blobName, nil)
 	if err != nil {
 		return nil, err
 	}
 	// NOTE: automatically retries are performed if the connection fails
-	bodyStream := downloadResponse.Body(azblob.RetryReaderOptions{MaxRetryRequests: 20})
+	retryReader := downloadResponse.NewRetryReader(ctx, &azblob.RetryReaderOptions{})
 
 	// read the body into a buffer
 	downloadedData := bytes.Buffer{}
-	_, err = downloadedData.ReadFrom(bodyStream)
+
+	_, err = downloadedData.ReadFrom(retryReader)
+	if err != nil {
+		return nil, err
+	}
+
+	err = retryReader.Close()
 	if err != nil {
 		return nil, err
 	}