| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956 |
- package aws
- import (
- "context"
- "errors"
- "testing"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/athena"
- "github.com/aws/aws-sdk-go-v2/service/athena/types"
- "github.com/opencost/opencost/pkg/cloud"
- )
- func TestAthenaQuerier_GetColumns(t *testing.T) {
- // Create mock client
- mockClient := &MockAthenaClient{}
- // Create mock querier with valid configuration
- querier := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- // Test successful column retrieval
- t.Run("successful_column_retrieval", func(t *testing.T) {
- // Mock successful query results with column names
- // First row is header, subsequent rows are data
- mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
- return &athena.GetQueryResultsOutput{
- ResultSet: &types.ResultSet{
- Rows: []types.Row{
- {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row
- {Data: []types.Datum{{VarCharValue: aws.String("column1")}}},
- {Data: []types.Datum{{VarCharValue: aws.String("column2")}}},
- {Data: []types.Datum{{VarCharValue: aws.String("column3")}}},
- },
- },
- }, nil
- }
- columns, err := querier.GetColumns()
- if err != nil {
- t.Errorf("GetColumns() returned error: %v", err)
- }
- expectedColumns := map[string]bool{
- "column1": true,
- "column2": true,
- "column3": true,
- }
- if len(columns) != len(expectedColumns) {
- t.Errorf("GetColumns() returned %d columns, want %d", len(columns), len(expectedColumns))
- }
- for col := range expectedColumns {
- if !columns[col] {
- t.Errorf("GetColumns() missing expected column: %s", col)
- }
- }
- })
- // Test empty results
- t.Run("empty_results", func(t *testing.T) {
- mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
- return &athena.GetQueryResultsOutput{
- ResultSet: &types.ResultSet{
- Rows: []types.Row{
- {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row only
- },
- },
- }, nil
- }
- columns, err := querier.GetColumns()
- if err != nil {
- t.Errorf("GetColumns() returned error: %v", err)
- }
- if len(columns) != 0 {
- t.Errorf("GetColumns() returned %d columns, want 0", len(columns))
- }
- })
- // Test query error
- t.Run("query_error", func(t *testing.T) {
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- return nil, errors.New("query execution failed")
- }
- columns, err := querier.GetColumns()
- if err == nil {
- t.Error("GetColumns() should return error when query fails")
- }
- if len(columns) != 0 {
- t.Errorf("GetColumns() should return empty map on error, got %d columns", len(columns))
- }
- })
- }
- func TestAthenaQuerier_Query(t *testing.T) {
- // Create mock client
- mockClient := &MockAthenaClient{}
- // Create mock querier with valid configuration
- querier := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- // Test successful query
- t.Run("successful_query", func(t *testing.T) {
- queryExecuted := false
- queryString := "SELECT * FROM test_table"
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- if *params.QueryString != queryString {
- t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
- }
- queryExecuted = true
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
- return &athena.GetQueryResultsOutput{
- ResultSet: &types.ResultSet{
- Rows: []types.Row{
- {Data: []types.Datum{{VarCharValue: aws.String("header")}}}, // Header row
- {Data: []types.Datum{{VarCharValue: aws.String("test-data")}}},
- },
- },
- }, nil
- }
- rowsProcessed := 0
- queryFunc := GetAthenaQueryFunc(func(row types.Row) {
- rowsProcessed++
- })
- err := querier.Query(context.Background(), queryString, queryFunc)
- if err != nil {
- t.Errorf("Query() returned error: %v", err)
- }
- if !queryExecuted {
- t.Error("Query execution was not called")
- }
- if rowsProcessed != 1 {
- t.Errorf("Expected 1 row to be processed, got %d", rowsProcessed)
- }
- // Check connection status is successful
- if querier.ConnectionStatus != cloud.SuccessfulConnection {
- t.Errorf("Expected connection status to be SuccessfulConnection, got %s", querier.ConnectionStatus)
- }
- })
- // Test invalid configuration
- t.Run("invalid_configuration", func(t *testing.T) {
- invalidQuerier := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- // Missing required fields
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- err := invalidQuerier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
- if err == nil {
- t.Error("Query() should return error for invalid configuration")
- }
- if invalidQuerier.ConnectionStatus != cloud.InvalidConfiguration {
- t.Errorf("Expected connection status to be InvalidConfiguration, got %s", invalidQuerier.ConnectionStatus)
- }
- })
- // Test query execution failure
- t.Run("query_execution_failure", func(t *testing.T) {
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- return nil, errors.New("query execution failed")
- }
- err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
- if err == nil {
- t.Error("Query() should return error when query execution fails")
- }
- if querier.ConnectionStatus != cloud.FailedConnection {
- t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
- }
- })
- // Test query waiting failure
- t.Run("query_waiting_failure", func(t *testing.T) {
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateFailed,
- },
- },
- }, nil
- }
- err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
- if err == nil {
- t.Error("Query() should return error when query waiting fails")
- }
- if querier.ConnectionStatus != cloud.FailedConnection {
- t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
- }
- })
- }
- func TestAthenaQuerier_GetAthenaClient(t *testing.T) {
- // Test successful client creation
- t.Run("successful_client_creation", func(t *testing.T) {
- querier := &AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- }
- client, err := querier.GetAthenaClient()
- if err != nil {
- t.Errorf("GetAthenaClient() returned error: %v", err)
- }
- if client == nil {
- t.Error("GetAthenaClient() returned nil client")
- }
- })
- // Test client creation with service account authorizer
- t.Run("service_account_authorizer", func(t *testing.T) {
- querier := &AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &ServiceAccount{},
- },
- }
- client, err := querier.GetAthenaClient()
- if err != nil {
- t.Errorf("GetAthenaClient() with ServiceAccount returned error: %v", err)
- }
- if client == nil {
- t.Error("GetAthenaClient() returned nil client")
- }
- })
- // Test client creation with assume role authorizer
- t.Run("assume_role_authorizer", func(t *testing.T) {
- querier := &AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AssumeRole{
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- RoleARN: "arn:aws:iam::123456789012:role/test-role",
- },
- },
- }
- client, err := querier.GetAthenaClient()
- if err != nil {
- t.Errorf("GetAthenaClient() with AssumeRole returned error: %v", err)
- }
- if client == nil {
- t.Error("GetAthenaClient() returned nil client")
- }
- })
- // Test client creation failure with invalid authorizer
- t.Run("invalid_authorizer", func(t *testing.T) {
- querier := &AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "", Secret: ""}, // Invalid credentials
- },
- }
- client, err := querier.GetAthenaClient()
- if err == nil {
- t.Error("GetAthenaClient() should return error for invalid authorizer")
- }
- if client != nil {
- t.Error("GetAthenaClient() should return nil client on error")
- }
- })
- // Test client creation with different regions
- t.Run("different_regions", func(t *testing.T) {
- regions := []string{"us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"}
- for _, region := range regions {
- querier := &AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: region,
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- }
- client, err := querier.GetAthenaClient()
- if err != nil {
- t.Errorf("GetAthenaClient() for region %s returned error: %v", region, err)
- }
- if client == nil {
- t.Errorf("GetAthenaClient() for region %s returned nil client", region)
- }
- }
- })
- }
- func TestAthenaQuerier_queryAthenaPaginated(t *testing.T) {
- // Create mock client
- mockClient := &MockAthenaClient{}
- // Create mock querier with valid configuration
- querier := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- // Test successful paginated query
- t.Run("successful_paginated_query", func(t *testing.T) {
- queryString := "SELECT * FROM test_table"
- queryExecuted := false
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- if *params.QueryString != queryString {
- t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
- }
- if *params.QueryExecutionContext.Database != "test-db" {
- t.Errorf("Expected database test-db, got %s", *params.QueryExecutionContext.Database)
- }
- if *params.ResultConfiguration.OutputLocation != "test-bucket" {
- t.Errorf("Expected bucket test-bucket, got %s", *params.ResultConfiguration.OutputLocation)
- }
- queryExecuted = true
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
- return &athena.GetQueryResultsOutput{
- ResultSet: &types.ResultSet{
- Rows: []types.Row{
- {Data: []types.Datum{{VarCharValue: aws.String("row1")}}},
- {Data: []types.Datum{{VarCharValue: aws.String("row2")}}},
- },
- },
- }, nil
- }
- rowsProcessed := 0
- queryFunc := func(page *athena.GetQueryResultsOutput) bool {
- for range page.ResultSet.Rows {
- rowsProcessed++
- }
- return true
- }
- err := querier.queryAthenaPaginated(context.Background(), queryString, queryFunc)
- if err != nil {
- t.Errorf("queryAthenaPaginated() returned error: %v", err)
- }
- if !queryExecuted {
- t.Error("Query execution was not called")
- }
- if rowsProcessed != 2 {
- t.Errorf("Expected 2 rows to be processed, got %d", rowsProcessed)
- }
- })
- // Test query with catalog
- t.Run("query_with_catalog", func(t *testing.T) {
- querierWithCatalog := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Catalog: "test-catalog",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- catalogSet := false
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- if params.QueryExecutionContext.Catalog != nil && *params.QueryExecutionContext.Catalog == "test-catalog" {
- catalogSet = true
- }
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- err := querierWithCatalog.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
- if err != nil {
- t.Errorf("queryAthenaPaginated() with catalog returned error: %v", err)
- }
- if !catalogSet {
- t.Error("Catalog was not set in query execution context")
- }
- })
- // Test query with workgroup
- t.Run("query_with_workgroup", func(t *testing.T) {
- querierWithWorkgroup := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Workgroup: "test-workgroup",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- workgroupSet := false
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- if params.WorkGroup != nil && *params.WorkGroup == "test-workgroup" {
- workgroupSet = true
- }
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- err := querierWithWorkgroup.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
- if err != nil {
- t.Errorf("queryAthenaPaginated() with workgroup returned error: %v", err)
- }
- if !workgroupSet {
- t.Error("Workgroup was not set in query execution input")
- }
- })
- // Test query execution failure
- t.Run("query_execution_failure", func(t *testing.T) {
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- return nil, errors.New("query execution failed")
- }
- err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
- if err == nil {
- t.Error("queryAthenaPaginated() should return error when query execution fails")
- }
- expectedError := "QueryAthenaPaginated: start query error: query execution failed"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test query waiting failure
- t.Run("query_waiting_failure", func(t *testing.T) {
- mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
- return &athena.StartQueryExecutionOutput{
- QueryExecutionId: aws.String("test-query-id"),
- }, nil
- }
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateFailed,
- },
- },
- }, nil
- }
- err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
- if err == nil {
- t.Error("queryAthenaPaginated() should return error when query waiting fails")
- }
- expectedError := "QueryAthenaPaginated: query execution error: no query results available for query test-query-id"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test get client failure
- t.Run("get_client_failure", func(t *testing.T) {
- invalidQuerier := &FailingQueryAthenaQuerier{
- MockAthenaQuerier: MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- },
- }
- err := invalidQuerier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
- if err == nil {
- t.Error("queryAthenaPaginated() should return error when client creation fails")
- }
- expectedError := "QueryAthenaPaginated: GetAthenaClient error: failed to create client"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- }
- func TestAthenaQuerier_waitForQueryToComplete(t *testing.T) {
- // Create mock client
- mockClient := &MockAthenaClient{}
- // Create mock querier
- querier := &MockAthenaQuerier{
- AthenaQuerier: AthenaQuerier{
- AthenaConfiguration: AthenaConfiguration{
- Bucket: "test-bucket",
- Region: "us-east-1",
- Database: "test-db",
- Table: "test-table",
- Account: "123456789012",
- Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
- },
- },
- mockClient: mockClient,
- }
- // Test successful query completion
- t.Run("successful_query_completion", func(t *testing.T) {
- queryID := "test-query-id"
- callCount := 0
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- callCount++
- if *params.QueryExecutionId != queryID {
- t.Errorf("Expected query ID %s, got %s", queryID, *params.QueryExecutionId)
- }
- // Return SUCCEEDED on first call
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateSucceeded,
- },
- },
- }, nil
- }
- err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
- if err != nil {
- t.Errorf("waitForQueryToComplete() returned error: %v", err)
- }
- if callCount != 1 {
- t.Errorf("Expected 1 call to GetQueryExecution, got %d", callCount)
- }
- })
- // Test query failure
- t.Run("query_failure", func(t *testing.T) {
- queryID := "test-query-id"
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateFailed,
- },
- },
- }, nil
- }
- err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when query fails")
- }
- expectedError := "no query results available for query test-query-id"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test query cancellation
- t.Run("query_cancellation", func(t *testing.T) {
- queryID := "test-query-id"
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateCancelled,
- },
- },
- }, nil
- }
- err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when query is cancelled")
- }
- expectedError := "no query results available for query test-query-id"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test query timeout
- t.Run("query_timeout", func(t *testing.T) {
- queryID := "test-query-id"
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: "TIMED_OUT", // Use string literal since QueryExecutionStateTimedOut doesn't exist
- },
- },
- }, nil
- }
- err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when query times out")
- }
- expectedError := "no query results available for query test-query-id"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test GetQueryExecution error
- t.Run("get_query_execution_error", func(t *testing.T) {
- queryID := "test-query-id"
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- return nil, errors.New("failed to get query execution")
- }
- err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when GetQueryExecution fails")
- }
- expectedError := "failed to get query execution"
- if err.Error() != expectedError {
- t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
- }
- })
- // Test context cancellation
- t.Run("context_cancellation", func(t *testing.T) {
- queryID := "test-query-id"
- ctx, cancel := context.WithCancel(context.Background())
- cancel() // Cancel immediately
- mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
- // Check if context is cancelled
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- return &athena.GetQueryExecutionOutput{
- QueryExecution: &types.QueryExecution{
- Status: &types.QueryExecutionStatus{
- State: types.QueryExecutionStateSucceeded,
- },
- },
- }, nil
- }
- }
- err := querier.waitForQueryToComplete(ctx, mockClient, &queryID)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when context is cancelled")
- }
- if err != context.Canceled {
- t.Errorf("Expected context.Canceled error, got %v", err)
- }
- })
- // Test with nil query execution ID
- t.Run("nil_query_execution_id", func(t *testing.T) {
- err := querier.waitForQueryToComplete(context.Background(), mockClient, nil)
- if err == nil {
- t.Error("waitForQueryToComplete() should return error when query execution ID is nil")
- }
- })
- }
- func TestAthenaQuerier_GetAthenaQueryFunc(t *testing.T) {
- // Test that GetAthenaQueryFunc returns a function
- queryFunc := GetAthenaQueryFunc(func(row types.Row) {
- // Do nothing
- })
- if queryFunc == nil {
- t.Error("GetAthenaQueryFunc should return a non-nil function")
- }
- // Test that the returned function can be called
- result := &athena.GetQueryResultsOutput{
- ResultSet: &types.ResultSet{
- Rows: []types.Row{
- {Data: []types.Datum{}},
- },
- },
- }
- // Should not panic
- queryFunc(result)
- }
- func TestGetAthenaRowValue(t *testing.T) {
- // Test with valid data
- row := types.Row{
- Data: []types.Datum{
- {VarCharValue: stringPtr("test-value")},
- },
- }
- queryColumnIndexes := map[string]int{
- "test-column": 0,
- }
- result := GetAthenaRowValue(row, queryColumnIndexes, "test-column")
- if result != "test-value" {
- t.Errorf("GetAthenaRowValue() = %v, want %v", result, "test-value")
- }
- // Test with missing column
- result = GetAthenaRowValue(row, queryColumnIndexes, "missing-column")
- if result != "" {
- t.Errorf("GetAthenaRowValue() with missing column = %v, want %v", result, "")
- }
- // Test with nil value
- rowWithNil := types.Row{
- Data: []types.Datum{
- {VarCharValue: nil},
- },
- }
- result = GetAthenaRowValue(rowWithNil, queryColumnIndexes, "test-column")
- if result != "" {
- t.Errorf("GetAthenaRowValue() with nil value = %v, want %v", result, "")
- }
- }
- func TestGetAthenaRowValueFloat(t *testing.T) {
- // Test with valid data
- row := types.Row{
- Data: []types.Datum{
- {VarCharValue: stringPtr("3.14159")},
- },
- }
- queryColumnIndexes := map[string]int{
- "test-column": 0,
- }
- result, err := GetAthenaRowValueFloat(row, queryColumnIndexes, "test-column")
- if err != nil {
- t.Errorf("GetAthenaRowValueFloat() returned error: %v", err)
- }
- if result != 3.14159 {
- t.Errorf("GetAthenaRowValueFloat() = %v, want %v", result, 3.14159)
- }
- // Test with missing column
- _, err = GetAthenaRowValueFloat(row, queryColumnIndexes, "missing-column")
- if err == nil {
- t.Error("GetAthenaRowValueFloat() should return error for missing column")
- }
- // Test with nil value
- rowWithNil := types.Row{
- Data: []types.Datum{
- {VarCharValue: nil},
- },
- }
- _, err = GetAthenaRowValueFloat(rowWithNil, queryColumnIndexes, "test-column")
- if err == nil {
- t.Error("GetAthenaRowValueFloat() should return error for nil value")
- }
- // Test with invalid float
- rowWithInvalid := types.Row{
- Data: []types.Datum{
- {VarCharValue: stringPtr("not-a-number")},
- },
- }
- _, err = GetAthenaRowValueFloat(rowWithInvalid, queryColumnIndexes, "test-column")
- if err == nil {
- t.Error("GetAthenaRowValueFloat() should return error for invalid float")
- }
- }
- func TestSelectAWSCategory(t *testing.T) {
- // Test network category (usage type ending in "Bytes")
- category := SelectAWSCategory("", "DataTransfer-Bytes", "")
- if category != "Network" {
- t.Errorf("SelectAWSCategory() for network = %v, want %v", category, "Network")
- }
- // Test compute category (provider ID with "i-" prefix)
- category = SelectAWSCategory("i-123456789", "", "")
- if category != "Compute" {
- t.Errorf("SelectAWSCategory() for compute = %v, want %v", category, "Compute")
- }
- // Test GuardDuty special case
- category = SelectAWSCategory("i-123456789", "", "AmazonGuardDuty")
- if category != "Other" {
- t.Errorf("SelectAWSCategory() for GuardDuty = %v, want %v", category, "Other")
- }
- // Test storage category (provider ID with "vol-" prefix)
- category = SelectAWSCategory("vol-123456789", "", "")
- if category != "Storage" {
- t.Errorf("SelectAWSCategory() for storage = %v, want %v", category, "Storage")
- }
- // Test service-based categories
- category = SelectAWSCategory("", "", "AmazonEKS")
- if category != "Management" {
- t.Errorf("SelectAWSCategory() for EKS = %v, want %v", category, "Management")
- }
- // Test fargate pod in EKS
- category = SelectAWSCategory("arn:aws:eks:us-west-2:123456789012:pod/cluster-name/pod-name", "", "AmazonEKS")
- if category != "Compute" {
- t.Errorf("SelectAWSCategory() for EKS fargate pod = %v, want %v", category, "Compute")
- }
- // Test other category as default
- category = SelectAWSCategory("", "", "SomeUnknownService")
- if category != "Other" {
- t.Errorf("SelectAWSCategory() for unknown service = %v, want %v", category, "Other")
- }
- }
- func TestParseARN(t *testing.T) {
- // Test valid ARN
- id := "arn:aws:elasticloadbalancing:us-east-1:297945954695:loadbalancer/a406f7761142e4ef58a8f2ba478d2db2"
- expected := "a406f7761142e4ef58a8f2ba478d2db2"
- result := ParseARN(id)
- if result != expected {
- t.Errorf("ParseARN() = %v, want %v", result, expected)
- }
- // Test invalid ARN (no match)
- id = "not-an-arn"
- result = ParseARN(id)
- if result != id {
- t.Errorf("ParseARN() for invalid ARN = %v, want %v", result, id)
- }
- // Test empty string
- id = ""
- result = ParseARN(id)
- if result != id {
- t.Errorf("ParseARN() for empty string = %v, want %v", result, id)
- }
- }
|