Ver Fonte

fix 3213: athena integration column 'year' cannot be resolved (#3268)

Signed-off-by: Sparsh <sparsh.raj30@gmail.com>
Sparsh Raj há 8 meses atrás
pai
commit
1ff0d54ccd

+ 29 - 0
pkg/cloud/aws/athenaconfiguration.go

@@ -18,6 +18,7 @@ type AthenaConfiguration struct {
 	Workgroup  string     `json:"workgroup"`
 	Account    string     `json:"account"`
 	Authorizer Authorizer `json:"authorizer"`
+	CURVersion string     `json:"curVersion,omitempty"` // "1.0" or "2.0", defaults to "2.0" if not specified
 }
 
 func (ac *AthenaConfiguration) Validate() error {
@@ -53,6 +54,11 @@ func (ac *AthenaConfiguration) Validate() error {
 		return fmt.Errorf("AthenaConfiguration: missing account")
 	}
 
+	// Validate CURVersion if specified
+	if ac.CURVersion != "" && ac.CURVersion != "1.0" && ac.CURVersion != "2.0" {
+		return fmt.Errorf("AthenaConfiguration: invalid CURVersion '%s', must be '1.0' or '2.0'", ac.CURVersion)
+	}
+
 	return nil
 }
 
@@ -103,6 +109,10 @@ func (ac *AthenaConfiguration) Equals(config cloud.Config) bool {
 		return false
 	}
 
+	if ac.CURVersion != thatConfig.CURVersion {
+		return false
+	}
+
 	return true
 }
 
@@ -116,6 +126,7 @@ func (ac *AthenaConfiguration) Sanitize() cloud.Config {
 		Workgroup:  ac.Workgroup,
 		Account:    ac.Account,
 		Authorizer: ac.Authorizer.Sanitize().(Authorizer),
+		CURVersion: ac.CURVersion,
 	}
 }
 
@@ -190,6 +201,18 @@ func (ac *AthenaConfiguration) UnmarshalJSON(b []byte) error {
 	}
 	ac.Authorizer = authorizer
 
+	// Parse CURVersion if present (optional field)
+	if _, ok := fmap["curVersion"]; ok {
+		curVersion, err := cloud.GetInterfaceValue[string](fmap, "curVersion")
+		if err != nil {
+			return fmt.Errorf("AthenaConfiguration: UnmarshalJSON: %w", err)
+		}
+		ac.CURVersion = curVersion
+	} else {
+		// Default to 2.0 if not specified
+		ac.CURVersion = "2.0"
+	}
+
 	return nil
 }
 
@@ -220,6 +243,11 @@ func ConvertAwsAthenaInfoToConfig(aai AwsAthenaInfo) cloud.KeyedConfig {
 
 	var config cloud.KeyedConfig
 	if aai.AthenaTable != "" || aai.AthenaDatabase != "" {
+		// Use CURVersion from config if specified, otherwise default to 2.0
+		curVersion := aai.CURVersion
+		if curVersion == "" {
+			curVersion = "2.0"
+		}
 		config = &AthenaConfiguration{
 			Bucket:     aai.AthenaBucketName,
 			Region:     aai.AthenaRegion,
@@ -229,6 +257,7 @@ func ConvertAwsAthenaInfoToConfig(aai AwsAthenaInfo) cloud.KeyedConfig {
 			Workgroup:  aai.AthenaWorkgroup,
 			Account:    aai.AccountID,
 			Authorizer: authorizer,
+			CURVersion: curVersion,
 		}
 	} else {
 		config = &S3Configuration{

+ 159 - 1
pkg/cloud/aws/athenaconfiguration_test.go

@@ -161,6 +161,62 @@ func TestAthenaConfiguration_Validate(t *testing.T) {
 			},
 			expected: fmt.Errorf("AthenaConfiguration: missing account"),
 		},
+		"valid CUR version 1.0": {
+			config: AthenaConfiguration{
+				Bucket:     "bucket",
+				Region:     "region",
+				Database:   "database",
+				Catalog:    "catalog",
+				Table:      "table",
+				Workgroup:  "workgroup",
+				Account:    "account",
+				Authorizer: &ServiceAccount{},
+				CURVersion: "1.0",
+			},
+			expected: nil,
+		},
+		"valid CUR version 2.0": {
+			config: AthenaConfiguration{
+				Bucket:     "bucket",
+				Region:     "region",
+				Database:   "database",
+				Catalog:    "catalog",
+				Table:      "table",
+				Workgroup:  "workgroup",
+				Account:    "account",
+				Authorizer: &ServiceAccount{},
+				CURVersion: "2.0",
+			},
+			expected: nil,
+		},
+		"valid empty CUR version defaults to 2.0": {
+			config: AthenaConfiguration{
+				Bucket:     "bucket",
+				Region:     "region",
+				Database:   "database",
+				Catalog:    "catalog",
+				Table:      "table",
+				Workgroup:  "workgroup",
+				Account:    "account",
+				Authorizer: &ServiceAccount{},
+				CURVersion: "",
+			},
+			expected: nil,
+		},
+		"invalid CUR version": {
+			config: AthenaConfiguration{
+				Bucket:     "bucket",
+				Region:     "region",
+				Database:   "database",
+				Catalog:    "catalog",
+				Table:      "table",
+				Workgroup:  "workgroup",
+				Account:    "account",
+				Authorizer: &ServiceAccount{},
+				CURVersion: "3.0",
+			},
+			expected: fmt.Errorf("AthenaConfiguration: invalid CURVersion '3.0', must be '1.0' or '2.0'"),
+		},
 	}
 
 	for name, testCase := range testCases {
@@ -515,6 +571,68 @@ func TestAthenaConfiguration_Equals(t *testing.T) {
 			},
 			expected: false,
 		},
+		"different CUR version": {
+			left: AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "1.0",
+			},
+			right: &AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "2.0",
+			},
+			expected: false,
+		},
+		"matching CUR version": {
+			left: AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "1.0",
+			},
+			right: &AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "1.0",
+			},
+			expected: true,
+		},
 		"different config": {
 			left: AthenaConfiguration{
 				Bucket:    "bucket",
@@ -551,7 +669,9 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 		config AthenaConfiguration
 	}{
 		"Empty Config": {
-			config: AthenaConfiguration{},
+			config: AthenaConfiguration{
+				CURVersion: "2.0", // Default value after JSON unmarshal
+			},
 		},
 		"AccessKey": {
 			config: AthenaConfiguration{
@@ -566,6 +686,7 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 					ID:     "id",
 					Secret: "secret",
 				},
+				CURVersion: "2.0", // Default value after JSON unmarshal
 			},
 		},
 
@@ -579,6 +700,7 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 				Workgroup:  "workgroup",
 				Account:    "account",
 				Authorizer: &ServiceAccount{},
+				CURVersion: "2.0", // Default value after JSON unmarshal
 			},
 		},
 		"AssumeRole with AccessKey": {
@@ -597,6 +719,7 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 					},
 					RoleARN: "12345",
 				},
+				CURVersion: "2.0", // Default value after JSON unmarshal
 			},
 		},
 		"AssumeRole with ServiceAccount": {
@@ -612,6 +735,7 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 					Authorizer: &ServiceAccount{},
 					RoleARN:    "12345",
 				},
+				CURVersion: "2.0", // Default value after JSON unmarshal
 			},
 		},
 		"RoleArnNil": {
@@ -627,6 +751,7 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 					Authorizer: nil,
 					RoleARN:    "12345",
 				},
+				CURVersion: "2.0", // Default value after JSON unmarshal
 			},
 		},
 		"AssumeRole with AssumeRole with ServiceAccount": {
@@ -645,6 +770,39 @@ func TestAthenaConfiguration_JSON(t *testing.T) {
 					},
 					RoleARN: "12345",
 				},
+				CURVersion: "2.0", // Default value after JSON unmarshal
+			},
+		},
+		"CUR Version 1.0": {
+			config: AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "1.0",
+			},
+		},
+		"CUR Version 2.0": {
+			config: AthenaConfiguration{
+				Bucket:    "bucket",
+				Region:    "region",
+				Database:  "database",
+				Catalog:   "catalog",
+				Table:     "table",
+				Workgroup: "workgroup",
+				Account:   "account",
+				Authorizer: &AccessKey{
+					ID:     "id",
+					Secret: "secret",
+				},
+				CURVersion: "2.0",
 			},
 		},
 	}

+ 21 - 1
pkg/cloud/aws/athenaintegration.go

@@ -328,8 +328,28 @@ func (ai *AthenaIntegration) GetPartitionWhere(start, end time.Time) string {
 	month := time.Date(start.Year(), start.Month(), 1, 0, 0, 0, 0, time.UTC)
 	endMonth := time.Date(end.Year(), end.Month(), 1, 0, 0, 0, 0, time.UTC)
 	var disjuncts []string
+	
+	// For CUR 2.0, check if billing_period partitions actually exist
+	useBillingPeriodPartitions := false
+	if ai.CURVersion != "1.0" {
+		// Check if billing_period partitions exist in the table
+		if hasBillingPeriod, err := ai.HasBillingPeriodPartitions(); err == nil && hasBillingPeriod {
+			useBillingPeriodPartitions = true
+		}
+	}
+	
 	for !month.After(endMonth) {
-		disjuncts = append(disjuncts, fmt.Sprintf("(year = '%d' AND month = '%d')", month.Year(), month.Month()))
+		if ai.CURVersion == "1.0" {
+			// CUR 1.0 uses year and month columns for partitioning
+			disjuncts = append(disjuncts, fmt.Sprintf("(year = '%d' AND month = '%d')", month.Year(), month.Month()))
+		} else if useBillingPeriodPartitions {
+			// CUR 2.0 with billing_period partitions
+			disjuncts = append(disjuncts, fmt.Sprintf("(billing_period = '%d-%02d')", month.Year(), month.Month()))
+		} else {
+			// CUR 2.0 fallback - use date_format functions (less efficient but works without partitions)
+			disjuncts = append(disjuncts, fmt.Sprintf("(date_format(line_item_usage_start_date, '%%Y') = '%d' AND date_format(line_item_usage_start_date, '%%m') = '%02d')",
+				month.Year(), month.Month()))
+		}
 		month = month.AddDate(0, 1, 0)
 	}
 	str := fmt.Sprintf("(%s)", strings.Join(disjuncts, " OR "))

+ 232 - 0
pkg/cloud/aws/athenaintegration_test.go

@@ -1,8 +1,10 @@
 package aws
 
 import (
+	"fmt"
 	"os"
 	"reflect"
+	"strings"
 	"testing"
 	"time"
 
@@ -396,3 +398,233 @@ func stringsToRow(strings []string) types.Row {
 	}
 	return types.Row{Data: data}
 }
+
+// mockAthenaQuerier is a mock that overrides HasBillingPeriodPartitions for testing
+type mockAthenaQuerier struct {
+	AthenaQuerier
+	hasBillingPeriodPartitions bool
+}
+
+func (m *mockAthenaQuerier) HasBillingPeriodPartitions() (bool, error) {
+	return m.hasBillingPeriodPartitions, nil
+}
+
+// mockAthenaIntegration is a mock that uses mockAthenaQuerier
+type mockAthenaIntegration struct {
+	*mockAthenaQuerier
+}
+
+func (m *mockAthenaIntegration) GetPartitionWhere(start, end time.Time) string {
+	// The partition logic using our mock's HasBillingPeriodPartitions result
+	month := time.Date(start.Year(), start.Month(), 1, 0, 0, 0, 0, time.UTC)
+	endMonth := time.Date(end.Year(), end.Month(), 1, 0, 0, 0, 0, time.UTC)
+	var disjuncts []string
+	
+	// Using our mock's result for billing period partitions
+	useBillingPeriodPartitions := false
+	if m.mockAthenaQuerier.AthenaConfiguration.CURVersion != "1.0" {
+		useBillingPeriodPartitions = m.mockAthenaQuerier.hasBillingPeriodPartitions
+	}
+	
+	for !month.After(endMonth) {
+		if m.mockAthenaQuerier.AthenaConfiguration.CURVersion == "1.0" {
+			// CUR 1.0 uses year and month columns for partitioning
+			disjuncts = append(disjuncts, fmt.Sprintf("(year = '%d' AND month = '%d')", month.Year(), month.Month()))
+		} else if useBillingPeriodPartitions {
+			// CUR 2.0 with billing_period partitions
+			disjuncts = append(disjuncts, fmt.Sprintf("(billing_period = '%d-%02d')", month.Year(), month.Month()))
+		} else {
+			// CUR 2.0 fallback - use date_format functions
+			disjuncts = append(disjuncts, fmt.Sprintf("(date_format(line_item_usage_start_date, '%%Y') = '%d' AND date_format(line_item_usage_start_date, '%%m') = '%02d')",
+				month.Year(), month.Month()))
+		}
+		month = month.AddDate(0, 1, 0)
+	}
+	return fmt.Sprintf("(%s)", strings.Join(disjuncts, " OR "))
+}
+
+func TestAthenaIntegration_GetPartitionWhere(t *testing.T) {
+	testCases := map[string]struct {
+		integration interface{ GetPartitionWhere(time.Time, time.Time) string }
+		start       time.Time
+		end         time.Time
+		expected    string
+	}{
+		"CUR 1.0 single month": {
+			integration: &AthenaIntegration{
+				AthenaQuerier: AthenaQuerier{
+					AthenaConfiguration: AthenaConfiguration{
+						Bucket:     "bucket",
+						Region:     "region",
+						Database:   "database",
+						Table:      "table",
+						Workgroup:  "workgroup",
+						Account:    "account",
+						Authorizer: &ServiceAccount{},
+						CURVersion: "1.0",
+					},
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 1, 25, 0, 0, 0, 0, time.UTC),
+			expected: "((year = '2024' AND month = '1'))",
+		},
+		"CUR 2.0 single month": {
+			integration: &mockAthenaIntegration{
+				mockAthenaQuerier: &mockAthenaQuerier{
+					AthenaQuerier: AthenaQuerier{
+						AthenaConfiguration: AthenaConfiguration{
+							Bucket:     "bucket",
+							Region:     "region",
+							Database:   "database",
+							Table:      "table",
+							Workgroup:  "workgroup",
+							Account:    "account",
+							Authorizer: &ServiceAccount{},
+							CURVersion: "2.0",
+						},
+					},
+					hasBillingPeriodPartitions: true,
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 1, 25, 0, 0, 0, 0, time.UTC),
+			expected: "((billing_period = '2024-01'))",
+		},
+		"CUR 1.0 multiple months": {
+			integration: &AthenaIntegration{
+				AthenaQuerier: AthenaQuerier{
+					AthenaConfiguration: AthenaConfiguration{
+						Bucket:     "bucket",
+						Region:     "region",
+						Database:   "database",
+						Table:      "table",
+						Workgroup:  "workgroup",
+						Account:    "account",
+						Authorizer: &ServiceAccount{},
+						CURVersion: "1.0",
+					},
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 3, 10, 0, 0, 0, 0, time.UTC),
+			expected: "((year = '2024' AND month = '1') OR (year = '2024' AND month = '2') OR (year = '2024' AND month = '3'))",
+		},
+		"CUR 2.0 multiple months": {
+			integration: &mockAthenaIntegration{
+				mockAthenaQuerier: &mockAthenaQuerier{
+					AthenaQuerier: AthenaQuerier{
+						AthenaConfiguration: AthenaConfiguration{
+							Bucket:     "bucket",
+							Region:     "region",
+							Database:   "database",
+							Table:      "table",
+							Workgroup:  "workgroup",
+							Account:    "account",
+							Authorizer: &ServiceAccount{},
+							CURVersion: "2.0",
+						},
+					},
+					hasBillingPeriodPartitions: true,
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 3, 10, 0, 0, 0, 0, time.UTC),
+			expected: "((billing_period = '2024-01') OR (billing_period = '2024-02') OR (billing_period = '2024-03'))",
+		},
+		"CUR 2.0 across year boundary": {
+			integration: &mockAthenaIntegration{
+				mockAthenaQuerier: &mockAthenaQuerier{
+					AthenaQuerier: AthenaQuerier{
+						AthenaConfiguration: AthenaConfiguration{
+							Bucket:     "bucket",
+							Region:     "region",
+							Database:   "database",
+							Table:      "table",
+							Workgroup:  "workgroup",
+							Account:    "account",
+							Authorizer: &ServiceAccount{},
+							CURVersion: "2.0",
+						},
+					},
+					hasBillingPeriodPartitions: true,
+				},
+			},
+			start:    time.Date(2023, 12, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC),
+			expected: "((billing_period = '2023-12') OR (billing_period = '2024-01') OR (billing_period = '2024-02'))",
+		},
+		"CUR 1.0 across year boundary": {
+			integration: &AthenaIntegration{
+				AthenaQuerier: AthenaQuerier{
+					AthenaConfiguration: AthenaConfiguration{
+						Bucket:     "bucket",
+						Region:     "region",
+						Database:   "database",
+						Table:      "table",
+						Workgroup:  "workgroup",
+						Account:    "account",
+						Authorizer: &ServiceAccount{},
+						CURVersion: "1.0",
+					},
+				},
+			},
+			start:    time.Date(2023, 12, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC),
+			expected: "((year = '2023' AND month = '12') OR (year = '2024' AND month = '1') OR (year = '2024' AND month = '2'))",
+		},
+		"Default CUR version (empty string defaults to 2.0)": {
+			integration: &mockAthenaIntegration{
+				mockAthenaQuerier: &mockAthenaQuerier{
+					AthenaQuerier: AthenaQuerier{
+						AthenaConfiguration: AthenaConfiguration{
+							Bucket:     "bucket",
+							Region:     "region",
+							Database:   "database",
+							Table:      "table",
+							Workgroup:  "workgroup",
+							Account:    "account",
+							Authorizer: &ServiceAccount{},
+							CURVersion: "",
+						},
+					},
+					hasBillingPeriodPartitions: true,
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 1, 25, 0, 0, 0, 0, time.UTC),
+			expected: "((billing_period = '2024-01'))",
+		},
+		"CUR 2.0 fallback when no billing_period partitions": {
+			integration: &mockAthenaIntegration{
+				mockAthenaQuerier: &mockAthenaQuerier{
+					AthenaQuerier: AthenaQuerier{
+						AthenaConfiguration: AthenaConfiguration{
+							Bucket:     "bucket",
+							Region:     "region",
+							Database:   "database",
+							Table:      "table",
+							Workgroup:  "workgroup",
+							Account:    "account",
+							Authorizer: &ServiceAccount{},
+							CURVersion: "2.0",
+						},
+					},
+					hasBillingPeriodPartitions: false, // No billing_period partitions
+				},
+			},
+			start:    time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC),
+			end:      time.Date(2024, 1, 25, 0, 0, 0, 0, time.UTC),
+			expected: "((date_format(line_item_usage_start_date, '%Y') = '2024' AND date_format(line_item_usage_start_date, '%m') = '01'))",
+		},
+	}
+
+	for name, testCase := range testCases {
+		t.Run(name, func(t *testing.T) {
+			actual := testCase.integration.GetPartitionWhere(testCase.start, testCase.end)
+			if actual != testCase.expected {
+				t.Errorf("GetPartitionWhere() mismatch:\nActual:   %s\nExpected: %s", actual, testCase.expected)
+			}
+		})
+	}
+}

+ 26 - 0
pkg/cloud/aws/athenaquerier.go

@@ -63,6 +63,32 @@ func (aq *AthenaQuerier) GetColumns() (map[string]bool, error) {
 	return columnSet, nil
 }
 
+// HasBillingPeriodPartitions checks if the table uses billing_period partitioning
+// by querying SHOW PARTITIONS and looking for billing_period partition keys
+func (aq *AthenaQuerier) HasBillingPeriodPartitions() (bool, error) {
+	// Use SHOW PARTITIONS to check if billing_period partitions exist
+	query := fmt.Sprintf("SHOW PARTITIONS \"%s\"", aq.Table)
+	hasBillingPeriodPartition := false
+	
+	athenaErr := aq.Query(context.TODO(), query, GetAthenaQueryFunc(func(row types.Row) {
+		if len(row.Data) > 0 && row.Data[0].VarCharValue != nil {
+			partitionValue := *row.Data[0].VarCharValue
+			// Check if partition follows billing_period=YYYY-MM format
+			if strings.HasPrefix(partitionValue, "billing_period=") {
+				hasBillingPeriodPartition = true
+			}
+		}
+	}))
+
+	if athenaErr != nil {
+		// If SHOW PARTITIONS fails, assume no billing_period partitions
+		log.Debugf("AthenaQuerier[%s]: SHOW PARTITIONS failed: %s", aq.Key(), athenaErr.Error())
+		return false, athenaErr
+	}
+
+	return hasBillingPeriodPartition, nil
+}
+
 func (aq *AthenaQuerier) Query(ctx context.Context, query string, fn func(*athena.GetQueryResultsOutput) bool) error {
 	err := aq.Validate()
 	if err != nil {

+ 5 - 0
pkg/cloud/aws/provider.go

@@ -386,6 +386,7 @@ type AwsAthenaInfo struct {
 	ServiceKeySecret string `json:"serviceKeySecret"`
 	AccountID        string `json:"projectID"`
 	MasterPayerARN   string `json:"masterPayerARN"`
+	CURVersion       string `json:"curVersion"` // "1.0" or "2.0", defaults to "2.0" if not specified
 }
 
 // IsEmpty returns true if all fields in config are empty, false if not.
@@ -501,6 +502,7 @@ func (aws *AWS) GetAWSAthenaInfo() (*AwsAthenaInfo, error) {
 		ServiceKeySecret: aak.SecretAccessKey,
 		AccountID:        config.AthenaProjectID,
 		MasterPayerARN:   config.MasterPayerARN,
+		CURVersion:       config.AthenaCURVersion,
 	}, nil
 }
 
@@ -561,6 +563,9 @@ func (aws *AWS) UpdateConfig(r io.Reader, updateType string) (*models.CustomPric
 				c.MasterPayerARN = aai.MasterPayerARN
 			}
 			c.AthenaProjectID = aai.AccountID
+			if aai.CURVersion != "" {
+				c.AthenaCURVersion = aai.CURVersion
+			}
 		} else {
 			a := make(map[string]interface{})
 			err := json.NewDecoder(r).Decode(&a)

+ 2 - 0
pkg/cloud/config/configurations_test.go

@@ -116,6 +116,7 @@ var (
 						ID:     "id",
 						Secret: "secret",
 					},
+					CURVersion: "2.0",
 				},
 			},
 		},
@@ -149,6 +150,7 @@ var (
 						Authorizer: &aws.ServiceAccount{},
 						RoleARN:    "roleArn",
 					},
+					CURVersion: "2.0",
 				},
 			},
 		},

+ 3 - 0
pkg/cloud/config/controller_test.go

@@ -22,6 +22,7 @@ var validAthenaConf = &aws.AthenaConfiguration{
 	Workgroup:  "workgroup",
 	Account:    "account",
 	Authorizer: &aws.ServiceAccount{},
+	CURVersion: "2.0",
 }
 
 // Config with the same key as the baseline but is not equal to it because of the change in the non-keyed property Workgroup
@@ -33,6 +34,7 @@ var validAthenaConfModifiedProperty = &aws.AthenaConfiguration{
 	Workgroup:  "workgroup1",
 	Account:    "account",
 	Authorizer: &aws.ServiceAccount{},
+	CURVersion: "2.0",
 }
 
 // Config with the same key as baseline but is invalid due to missing Authorizer
@@ -44,6 +46,7 @@ var invalidAthenaConf = &aws.AthenaConfiguration{
 	Workgroup:  "workgroup",
 	Account:    "account",
 	Authorizer: nil,
+	CURVersion: "2.0",
 }
 
 // A valid config with a different key from the baseline

+ 1 - 0
pkg/cloud/models/models.go

@@ -163,6 +163,7 @@ type CustomPricing struct {
 	AthenaTable                  string `json:"athenaTable"`
 	AthenaWorkgroup              string `json:"athenaWorkgroup"`
 	MasterPayerARN               string `json:"masterPayerARN"`
+	AthenaCURVersion             string `json:"athenaCURVersion,omitempty"` // "1.0" or "2.0", defaults to "2.0"
 	BillingDataDataset           string `json:"billingDataDataset,omitempty"`
 	CustomPricesEnabled          string `json:"customPricesEnabled"`
 	DefaultIdle                  string `json:"defaultIdle"`

+ 32 - 0
pkg/cloud/models/models_test.go

@@ -116,3 +116,35 @@ func TestSetSetCustomPricingField(t *testing.T) {
 		})
 	}
 }
+
+func TestCustomPricing_AthenaCURVersion(t *testing.T) {
+	testCases := map[string]struct {
+		curVersion string
+		expected   string
+	}{
+		"CUR version 1.0": {
+			curVersion: "1.0",
+			expected:   "1.0",
+		},
+		"CUR version 2.0": {
+			curVersion: "2.0",
+			expected:   "2.0",
+		},
+		"empty CUR version": {
+			curVersion: "",
+			expected:   "",
+		},
+	}
+
+	for name, testCase := range testCases {
+		t.Run(name, func(t *testing.T) {
+			cp := &CustomPricing{
+				AthenaCURVersion: testCase.curVersion,
+			}
+			
+			if cp.AthenaCURVersion != testCase.expected {
+				t.Errorf("expected AthenaCURVersion to be '%s', got '%s'", testCase.expected, cp.AthenaCURVersion)
+			}
+		})
+	}
+}