|
|
@@ -0,0 +1,956 @@
|
|
|
+package aws
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "errors"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ "github.com/aws/aws-sdk-go-v2/aws"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/athena"
|
|
|
+ "github.com/aws/aws-sdk-go-v2/service/athena/types"
|
|
|
+ "github.com/opencost/opencost/pkg/cloud"
|
|
|
+)
|
|
|
+
|
|
|
+func TestAthenaQuerier_GetColumns(t *testing.T) {
|
|
|
+ // Create mock client
|
|
|
+ mockClient := &MockAthenaClient{}
|
|
|
+
|
|
|
+ // Create mock querier with valid configuration
|
|
|
+ querier := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test successful column retrieval
|
|
|
+ t.Run("successful_column_retrieval", func(t *testing.T) {
|
|
|
+ // Mock successful query results with column names
|
|
|
+ // First row is header, subsequent rows are data
|
|
|
+ mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
|
|
|
+ return &athena.GetQueryResultsOutput{
|
|
|
+ ResultSet: &types.ResultSet{
|
|
|
+ Rows: []types.Row{
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("column1")}}},
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("column2")}}},
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("column3")}}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ columns, err := querier.GetColumns()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetColumns() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedColumns := map[string]bool{
|
|
|
+ "column1": true,
|
|
|
+ "column2": true,
|
|
|
+ "column3": true,
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(columns) != len(expectedColumns) {
|
|
|
+ t.Errorf("GetColumns() returned %d columns, want %d", len(columns), len(expectedColumns))
|
|
|
+ }
|
|
|
+
|
|
|
+ for col := range expectedColumns {
|
|
|
+ if !columns[col] {
|
|
|
+ t.Errorf("GetColumns() missing expected column: %s", col)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test empty results
|
|
|
+ t.Run("empty_results", func(t *testing.T) {
|
|
|
+ mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
|
|
|
+ return &athena.GetQueryResultsOutput{
|
|
|
+ ResultSet: &types.ResultSet{
|
|
|
+ Rows: []types.Row{
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row only
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ columns, err := querier.GetColumns()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetColumns() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(columns) != 0 {
|
|
|
+ t.Errorf("GetColumns() returned %d columns, want 0", len(columns))
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query error
|
|
|
+ t.Run("query_error", func(t *testing.T) {
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ return nil, errors.New("query execution failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ columns, err := querier.GetColumns()
|
|
|
+ if err == nil {
|
|
|
+ t.Error("GetColumns() should return error when query fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(columns) != 0 {
|
|
|
+ t.Errorf("GetColumns() should return empty map on error, got %d columns", len(columns))
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestAthenaQuerier_Query(t *testing.T) {
|
|
|
+ // Create mock client
|
|
|
+ mockClient := &MockAthenaClient{}
|
|
|
+
|
|
|
+ // Create mock querier with valid configuration
|
|
|
+ querier := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test successful query
|
|
|
+ t.Run("successful_query", func(t *testing.T) {
|
|
|
+ queryExecuted := false
|
|
|
+ queryString := "SELECT * FROM test_table"
|
|
|
+
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ if *params.QueryString != queryString {
|
|
|
+ t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
|
|
|
+ }
|
|
|
+ queryExecuted = true
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
|
|
|
+ return &athena.GetQueryResultsOutput{
|
|
|
+ ResultSet: &types.ResultSet{
|
|
|
+ Rows: []types.Row{
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("header")}}}, // Header row
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("test-data")}}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ rowsProcessed := 0
|
|
|
+ queryFunc := GetAthenaQueryFunc(func(row types.Row) {
|
|
|
+ rowsProcessed++
|
|
|
+ })
|
|
|
+
|
|
|
+ err := querier.Query(context.Background(), queryString, queryFunc)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("Query() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !queryExecuted {
|
|
|
+ t.Error("Query execution was not called")
|
|
|
+ }
|
|
|
+
|
|
|
+ if rowsProcessed != 1 {
|
|
|
+ t.Errorf("Expected 1 row to be processed, got %d", rowsProcessed)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check connection status is successful
|
|
|
+ if querier.ConnectionStatus != cloud.SuccessfulConnection {
|
|
|
+ t.Errorf("Expected connection status to be SuccessfulConnection, got %s", querier.ConnectionStatus)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test invalid configuration
|
|
|
+ t.Run("invalid_configuration", func(t *testing.T) {
|
|
|
+ invalidQuerier := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ // Missing required fields
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ err := invalidQuerier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
|
|
|
+ if err == nil {
|
|
|
+ t.Error("Query() should return error for invalid configuration")
|
|
|
+ }
|
|
|
+
|
|
|
+ if invalidQuerier.ConnectionStatus != cloud.InvalidConfiguration {
|
|
|
+ t.Errorf("Expected connection status to be InvalidConfiguration, got %s", invalidQuerier.ConnectionStatus)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query execution failure
|
|
|
+ t.Run("query_execution_failure", func(t *testing.T) {
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ return nil, errors.New("query execution failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
|
|
|
+ if err == nil {
|
|
|
+ t.Error("Query() should return error when query execution fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ if querier.ConnectionStatus != cloud.FailedConnection {
|
|
|
+ t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query waiting failure
|
|
|
+ t.Run("query_waiting_failure", func(t *testing.T) {
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateFailed,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
|
|
|
+ if err == nil {
|
|
|
+ t.Error("Query() should return error when query waiting fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ if querier.ConnectionStatus != cloud.FailedConnection {
|
|
|
+ t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestAthenaQuerier_GetAthenaClient(t *testing.T) {
|
|
|
+ // Test successful client creation
|
|
|
+ t.Run("successful_client_creation", func(t *testing.T) {
|
|
|
+ querier := &AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := querier.GetAthenaClient()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetAthenaClient() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ t.Error("GetAthenaClient() returned nil client")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test client creation with service account authorizer
|
|
|
+ t.Run("service_account_authorizer", func(t *testing.T) {
|
|
|
+ querier := &AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &ServiceAccount{},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := querier.GetAthenaClient()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetAthenaClient() with ServiceAccount returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ t.Error("GetAthenaClient() returned nil client")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test client creation with assume role authorizer
|
|
|
+ t.Run("assume_role_authorizer", func(t *testing.T) {
|
|
|
+ querier := &AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AssumeRole{
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ RoleARN: "arn:aws:iam::123456789012:role/test-role",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := querier.GetAthenaClient()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetAthenaClient() with AssumeRole returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ t.Error("GetAthenaClient() returned nil client")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test client creation failure with invalid authorizer
|
|
|
+ t.Run("invalid_authorizer", func(t *testing.T) {
|
|
|
+ querier := &AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "", Secret: ""}, // Invalid credentials
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := querier.GetAthenaClient()
|
|
|
+ if err == nil {
|
|
|
+ t.Error("GetAthenaClient() should return error for invalid authorizer")
|
|
|
+ }
|
|
|
+
|
|
|
+ if client != nil {
|
|
|
+ t.Error("GetAthenaClient() should return nil client on error")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test client creation with different regions
|
|
|
+ t.Run("different_regions", func(t *testing.T) {
|
|
|
+ regions := []string{"us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"}
|
|
|
+
|
|
|
+ for _, region := range regions {
|
|
|
+ querier := &AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: region,
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ client, err := querier.GetAthenaClient()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetAthenaClient() for region %s returned error: %v", region, err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if client == nil {
|
|
|
+ t.Errorf("GetAthenaClient() for region %s returned nil client", region)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestAthenaQuerier_queryAthenaPaginated(t *testing.T) {
|
|
|
+ // Create mock client
|
|
|
+ mockClient := &MockAthenaClient{}
|
|
|
+
|
|
|
+ // Create mock querier with valid configuration
|
|
|
+ querier := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test successful paginated query
|
|
|
+ t.Run("successful_paginated_query", func(t *testing.T) {
|
|
|
+ queryString := "SELECT * FROM test_table"
|
|
|
+ queryExecuted := false
|
|
|
+
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ if *params.QueryString != queryString {
|
|
|
+ t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
|
|
|
+ }
|
|
|
+ if *params.QueryExecutionContext.Database != "test-db" {
|
|
|
+ t.Errorf("Expected database test-db, got %s", *params.QueryExecutionContext.Database)
|
|
|
+ }
|
|
|
+ if *params.ResultConfiguration.OutputLocation != "test-bucket" {
|
|
|
+ t.Errorf("Expected bucket test-bucket, got %s", *params.ResultConfiguration.OutputLocation)
|
|
|
+ }
|
|
|
+ queryExecuted = true
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
|
|
|
+ return &athena.GetQueryResultsOutput{
|
|
|
+ ResultSet: &types.ResultSet{
|
|
|
+ Rows: []types.Row{
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("row1")}}},
|
|
|
+ {Data: []types.Datum{{VarCharValue: aws.String("row2")}}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ rowsProcessed := 0
|
|
|
+ queryFunc := func(page *athena.GetQueryResultsOutput) bool {
|
|
|
+ for range page.ResultSet.Rows {
|
|
|
+ rowsProcessed++
|
|
|
+ }
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.queryAthenaPaginated(context.Background(), queryString, queryFunc)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("queryAthenaPaginated() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !queryExecuted {
|
|
|
+ t.Error("Query execution was not called")
|
|
|
+ }
|
|
|
+
|
|
|
+ if rowsProcessed != 2 {
|
|
|
+ t.Errorf("Expected 2 rows to be processed, got %d", rowsProcessed)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query with catalog
|
|
|
+ t.Run("query_with_catalog", func(t *testing.T) {
|
|
|
+ querierWithCatalog := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Catalog: "test-catalog",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ catalogSet := false
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ if params.QueryExecutionContext.Catalog != nil && *params.QueryExecutionContext.Catalog == "test-catalog" {
|
|
|
+ catalogSet = true
|
|
|
+ }
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querierWithCatalog.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("queryAthenaPaginated() with catalog returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !catalogSet {
|
|
|
+ t.Error("Catalog was not set in query execution context")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query with workgroup
|
|
|
+ t.Run("query_with_workgroup", func(t *testing.T) {
|
|
|
+ querierWithWorkgroup := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Workgroup: "test-workgroup",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ workgroupSet := false
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ if params.WorkGroup != nil && *params.WorkGroup == "test-workgroup" {
|
|
|
+ workgroupSet = true
|
|
|
+ }
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querierWithWorkgroup.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("queryAthenaPaginated() with workgroup returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !workgroupSet {
|
|
|
+ t.Error("Workgroup was not set in query execution input")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query execution failure
|
|
|
+ t.Run("query_execution_failure", func(t *testing.T) {
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ return nil, errors.New("query execution failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
|
|
|
+ if err == nil {
|
|
|
+ t.Error("queryAthenaPaginated() should return error when query execution fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "QueryAthenaPaginated: start query error: query execution failed"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query waiting failure
|
|
|
+ t.Run("query_waiting_failure", func(t *testing.T) {
|
|
|
+ mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
|
|
|
+ return &athena.StartQueryExecutionOutput{
|
|
|
+ QueryExecutionId: aws.String("test-query-id"),
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateFailed,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
|
|
|
+ if err == nil {
|
|
|
+ t.Error("queryAthenaPaginated() should return error when query waiting fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "QueryAthenaPaginated: query execution error: no query results available for query test-query-id"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test get client failure
|
|
|
+ t.Run("get_client_failure", func(t *testing.T) {
|
|
|
+ invalidQuerier := &FailingQueryAthenaQuerier{
|
|
|
+ MockAthenaQuerier: MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ err := invalidQuerier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
|
|
|
+ if err == nil {
|
|
|
+ t.Error("queryAthenaPaginated() should return error when client creation fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "QueryAthenaPaginated: GetAthenaClient error: failed to create client"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestAthenaQuerier_waitForQueryToComplete(t *testing.T) {
|
|
|
+ // Create mock client
|
|
|
+ mockClient := &MockAthenaClient{}
|
|
|
+
|
|
|
+ // Create mock querier
|
|
|
+ querier := &MockAthenaQuerier{
|
|
|
+ AthenaQuerier: AthenaQuerier{
|
|
|
+ AthenaConfiguration: AthenaConfiguration{
|
|
|
+ Bucket: "test-bucket",
|
|
|
+ Region: "us-east-1",
|
|
|
+ Database: "test-db",
|
|
|
+ Table: "test-table",
|
|
|
+ Account: "123456789012",
|
|
|
+ Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ mockClient: mockClient,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test successful query completion
|
|
|
+ t.Run("successful_query_completion", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+ callCount := 0
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ callCount++
|
|
|
+ if *params.QueryExecutionId != queryID {
|
|
|
+ t.Errorf("Expected query ID %s, got %s", queryID, *params.QueryExecutionId)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Return SUCCEEDED on first call
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateSucceeded,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("waitForQueryToComplete() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if callCount != 1 {
|
|
|
+ t.Errorf("Expected 1 call to GetQueryExecution, got %d", callCount)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query failure
|
|
|
+ t.Run("query_failure", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateFailed,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when query fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "no query results available for query test-query-id"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query cancellation
|
|
|
+ t.Run("query_cancellation", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateCancelled,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when query is cancelled")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "no query results available for query test-query-id"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test query timeout
|
|
|
+ t.Run("query_timeout", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: "TIMED_OUT", // Use string literal since QueryExecutionStateTimedOut doesn't exist
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when query times out")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "no query results available for query test-query-id"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test GetQueryExecution error
|
|
|
+ t.Run("get_query_execution_error", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ return nil, errors.New("failed to get query execution")
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when GetQueryExecution fails")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedError := "failed to get query execution"
|
|
|
+ if err.Error() != expectedError {
|
|
|
+ t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test context cancellation
|
|
|
+ t.Run("context_cancellation", func(t *testing.T) {
|
|
|
+ queryID := "test-query-id"
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ cancel() // Cancel immediately
|
|
|
+
|
|
|
+ mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
|
|
|
+ // Check if context is cancelled
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return nil, ctx.Err()
|
|
|
+ default:
|
|
|
+ return &athena.GetQueryExecutionOutput{
|
|
|
+ QueryExecution: &types.QueryExecution{
|
|
|
+ Status: &types.QueryExecutionStatus{
|
|
|
+ State: types.QueryExecutionStateSucceeded,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ err := querier.waitForQueryToComplete(ctx, mockClient, &queryID)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when context is cancelled")
|
|
|
+ }
|
|
|
+
|
|
|
+ if err != context.Canceled {
|
|
|
+ t.Errorf("Expected context.Canceled error, got %v", err)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ // Test with nil query execution ID
|
|
|
+ t.Run("nil_query_execution_id", func(t *testing.T) {
|
|
|
+ err := querier.waitForQueryToComplete(context.Background(), mockClient, nil)
|
|
|
+ if err == nil {
|
|
|
+ t.Error("waitForQueryToComplete() should return error when query execution ID is nil")
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestAthenaQuerier_GetAthenaQueryFunc(t *testing.T) {
|
|
|
+ // Test that GetAthenaQueryFunc returns a function
|
|
|
+ queryFunc := GetAthenaQueryFunc(func(row types.Row) {
|
|
|
+ // Do nothing
|
|
|
+ })
|
|
|
+
|
|
|
+ if queryFunc == nil {
|
|
|
+ t.Error("GetAthenaQueryFunc should return a non-nil function")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test that the returned function can be called
|
|
|
+ result := &athena.GetQueryResultsOutput{
|
|
|
+ ResultSet: &types.ResultSet{
|
|
|
+ Rows: []types.Row{
|
|
|
+ {Data: []types.Datum{}},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ // Should not panic
|
|
|
+ queryFunc(result)
|
|
|
+}
|
|
|
+
|
|
|
+func TestGetAthenaRowValue(t *testing.T) {
|
|
|
+ // Test with valid data
|
|
|
+ row := types.Row{
|
|
|
+ Data: []types.Datum{
|
|
|
+ {VarCharValue: stringPtr("test-value")},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ queryColumnIndexes := map[string]int{
|
|
|
+ "test-column": 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ result := GetAthenaRowValue(row, queryColumnIndexes, "test-column")
|
|
|
+ if result != "test-value" {
|
|
|
+ t.Errorf("GetAthenaRowValue() = %v, want %v", result, "test-value")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with missing column
|
|
|
+ result = GetAthenaRowValue(row, queryColumnIndexes, "missing-column")
|
|
|
+ if result != "" {
|
|
|
+ t.Errorf("GetAthenaRowValue() with missing column = %v, want %v", result, "")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with nil value
|
|
|
+ rowWithNil := types.Row{
|
|
|
+ Data: []types.Datum{
|
|
|
+ {VarCharValue: nil},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ result = GetAthenaRowValue(rowWithNil, queryColumnIndexes, "test-column")
|
|
|
+ if result != "" {
|
|
|
+ t.Errorf("GetAthenaRowValue() with nil value = %v, want %v", result, "")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestGetAthenaRowValueFloat(t *testing.T) {
|
|
|
+ // Test with valid data
|
|
|
+ row := types.Row{
|
|
|
+ Data: []types.Datum{
|
|
|
+ {VarCharValue: stringPtr("3.14159")},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ queryColumnIndexes := map[string]int{
|
|
|
+ "test-column": 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ result, err := GetAthenaRowValueFloat(row, queryColumnIndexes, "test-column")
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("GetAthenaRowValueFloat() returned error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if result != 3.14159 {
|
|
|
+ t.Errorf("GetAthenaRowValueFloat() = %v, want %v", result, 3.14159)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with missing column
|
|
|
+ _, err = GetAthenaRowValueFloat(row, queryColumnIndexes, "missing-column")
|
|
|
+ if err == nil {
|
|
|
+ t.Error("GetAthenaRowValueFloat() should return error for missing column")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with nil value
|
|
|
+ rowWithNil := types.Row{
|
|
|
+ Data: []types.Datum{
|
|
|
+ {VarCharValue: nil},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = GetAthenaRowValueFloat(rowWithNil, queryColumnIndexes, "test-column")
|
|
|
+ if err == nil {
|
|
|
+ t.Error("GetAthenaRowValueFloat() should return error for nil value")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test with invalid float
|
|
|
+ rowWithInvalid := types.Row{
|
|
|
+ Data: []types.Datum{
|
|
|
+ {VarCharValue: stringPtr("not-a-number")},
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = GetAthenaRowValueFloat(rowWithInvalid, queryColumnIndexes, "test-column")
|
|
|
+ if err == nil {
|
|
|
+ t.Error("GetAthenaRowValueFloat() should return error for invalid float")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSelectAWSCategory(t *testing.T) {
|
|
|
+ // Test network category (usage type ending in "Bytes")
|
|
|
+ category := SelectAWSCategory("", "DataTransfer-Bytes", "")
|
|
|
+ if category != "Network" {
|
|
|
+ t.Errorf("SelectAWSCategory() for network = %v, want %v", category, "Network")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test compute category (provider ID with "i-" prefix)
|
|
|
+ category = SelectAWSCategory("i-123456789", "", "")
|
|
|
+ if category != "Compute" {
|
|
|
+ t.Errorf("SelectAWSCategory() for compute = %v, want %v", category, "Compute")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test GuardDuty special case
|
|
|
+ category = SelectAWSCategory("i-123456789", "", "AmazonGuardDuty")
|
|
|
+ if category != "Other" {
|
|
|
+ t.Errorf("SelectAWSCategory() for GuardDuty = %v, want %v", category, "Other")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test storage category (provider ID with "vol-" prefix)
|
|
|
+ category = SelectAWSCategory("vol-123456789", "", "")
|
|
|
+ if category != "Storage" {
|
|
|
+ t.Errorf("SelectAWSCategory() for storage = %v, want %v", category, "Storage")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test service-based categories
|
|
|
+ category = SelectAWSCategory("", "", "AmazonEKS")
|
|
|
+ if category != "Management" {
|
|
|
+ t.Errorf("SelectAWSCategory() for EKS = %v, want %v", category, "Management")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test fargate pod in EKS
|
|
|
+ category = SelectAWSCategory("arn:aws:eks:us-west-2:123456789012:pod/cluster-name/pod-name", "", "AmazonEKS")
|
|
|
+ if category != "Compute" {
|
|
|
+ t.Errorf("SelectAWSCategory() for EKS fargate pod = %v, want %v", category, "Compute")
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test other category as default
|
|
|
+ category = SelectAWSCategory("", "", "SomeUnknownService")
|
|
|
+ if category != "Other" {
|
|
|
+ t.Errorf("SelectAWSCategory() for unknown service = %v, want %v", category, "Other")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestParseARN(t *testing.T) {
|
|
|
+ // Test valid ARN
|
|
|
+ id := "arn:aws:elasticloadbalancing:us-east-1:297945954695:loadbalancer/a406f7761142e4ef58a8f2ba478d2db2"
|
|
|
+ expected := "a406f7761142e4ef58a8f2ba478d2db2"
|
|
|
+ result := ParseARN(id)
|
|
|
+ if result != expected {
|
|
|
+ t.Errorf("ParseARN() = %v, want %v", result, expected)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test invalid ARN (no match)
|
|
|
+ id = "not-an-arn"
|
|
|
+ result = ParseARN(id)
|
|
|
+ if result != id {
|
|
|
+ t.Errorf("ParseARN() for invalid ARN = %v, want %v", result, id)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Test empty string
|
|
|
+ id = ""
|
|
|
+ result = ParseARN(id)
|
|
|
+ if result != id {
|
|
|
+ t.Errorf("ParseARN() for empty string = %v, want %v", result, id)
|
|
|
+ }
|
|
|
+}
|