2
0

athenaquerier_test.go 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956
  1. package aws
  2. import (
  3. "context"
  4. "errors"
  5. "testing"
  6. "github.com/aws/aws-sdk-go-v2/aws"
  7. "github.com/aws/aws-sdk-go-v2/service/athena"
  8. "github.com/aws/aws-sdk-go-v2/service/athena/types"
  9. "github.com/opencost/opencost/pkg/cloud"
  10. )
  11. func TestAthenaQuerier_GetColumns(t *testing.T) {
  12. // Create mock client
  13. mockClient := &MockAthenaClient{}
  14. // Create mock querier with valid configuration
  15. querier := &MockAthenaQuerier{
  16. AthenaQuerier: AthenaQuerier{
  17. AthenaConfiguration: AthenaConfiguration{
  18. Bucket: "test-bucket",
  19. Region: "us-east-1",
  20. Database: "test-db",
  21. Table: "test-table",
  22. Account: "123456789012",
  23. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  24. },
  25. },
  26. mockClient: mockClient,
  27. }
  28. // Test successful column retrieval
  29. t.Run("successful_column_retrieval", func(t *testing.T) {
  30. // Mock successful query results with column names
  31. // First row is header, subsequent rows are data
  32. mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
  33. return &athena.GetQueryResultsOutput{
  34. ResultSet: &types.ResultSet{
  35. Rows: []types.Row{
  36. {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row
  37. {Data: []types.Datum{{VarCharValue: aws.String("column1")}}},
  38. {Data: []types.Datum{{VarCharValue: aws.String("column2")}}},
  39. {Data: []types.Datum{{VarCharValue: aws.String("column3")}}},
  40. },
  41. },
  42. }, nil
  43. }
  44. columns, err := querier.GetColumns()
  45. if err != nil {
  46. t.Errorf("GetColumns() returned error: %v", err)
  47. }
  48. expectedColumns := map[string]bool{
  49. "column1": true,
  50. "column2": true,
  51. "column3": true,
  52. }
  53. if len(columns) != len(expectedColumns) {
  54. t.Errorf("GetColumns() returned %d columns, want %d", len(columns), len(expectedColumns))
  55. }
  56. for col := range expectedColumns {
  57. if !columns[col] {
  58. t.Errorf("GetColumns() missing expected column: %s", col)
  59. }
  60. }
  61. })
  62. // Test empty results
  63. t.Run("empty_results", func(t *testing.T) {
  64. mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
  65. return &athena.GetQueryResultsOutput{
  66. ResultSet: &types.ResultSet{
  67. Rows: []types.Row{
  68. {Data: []types.Datum{{VarCharValue: aws.String("column_name")}}}, // Header row only
  69. },
  70. },
  71. }, nil
  72. }
  73. columns, err := querier.GetColumns()
  74. if err != nil {
  75. t.Errorf("GetColumns() returned error: %v", err)
  76. }
  77. if len(columns) != 0 {
  78. t.Errorf("GetColumns() returned %d columns, want 0", len(columns))
  79. }
  80. })
  81. // Test query error
  82. t.Run("query_error", func(t *testing.T) {
  83. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  84. return nil, errors.New("query execution failed")
  85. }
  86. columns, err := querier.GetColumns()
  87. if err == nil {
  88. t.Error("GetColumns() should return error when query fails")
  89. }
  90. if len(columns) != 0 {
  91. t.Errorf("GetColumns() should return empty map on error, got %d columns", len(columns))
  92. }
  93. })
  94. }
  95. func TestAthenaQuerier_Query(t *testing.T) {
  96. // Create mock client
  97. mockClient := &MockAthenaClient{}
  98. // Create mock querier with valid configuration
  99. querier := &MockAthenaQuerier{
  100. AthenaQuerier: AthenaQuerier{
  101. AthenaConfiguration: AthenaConfiguration{
  102. Bucket: "test-bucket",
  103. Region: "us-east-1",
  104. Database: "test-db",
  105. Table: "test-table",
  106. Account: "123456789012",
  107. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  108. },
  109. },
  110. mockClient: mockClient,
  111. }
  112. // Test successful query
  113. t.Run("successful_query", func(t *testing.T) {
  114. queryExecuted := false
  115. queryString := "SELECT * FROM test_table"
  116. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  117. if *params.QueryString != queryString {
  118. t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
  119. }
  120. queryExecuted = true
  121. return &athena.StartQueryExecutionOutput{
  122. QueryExecutionId: aws.String("test-query-id"),
  123. }, nil
  124. }
  125. mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
  126. return &athena.GetQueryResultsOutput{
  127. ResultSet: &types.ResultSet{
  128. Rows: []types.Row{
  129. {Data: []types.Datum{{VarCharValue: aws.String("header")}}}, // Header row
  130. {Data: []types.Datum{{VarCharValue: aws.String("test-data")}}},
  131. },
  132. },
  133. }, nil
  134. }
  135. rowsProcessed := 0
  136. queryFunc := GetAthenaQueryFunc(func(row types.Row) {
  137. rowsProcessed++
  138. })
  139. err := querier.Query(context.Background(), queryString, queryFunc)
  140. if err != nil {
  141. t.Errorf("Query() returned error: %v", err)
  142. }
  143. if !queryExecuted {
  144. t.Error("Query execution was not called")
  145. }
  146. if rowsProcessed != 1 {
  147. t.Errorf("Expected 1 row to be processed, got %d", rowsProcessed)
  148. }
  149. // Check connection status is successful
  150. if querier.ConnectionStatus != cloud.SuccessfulConnection {
  151. t.Errorf("Expected connection status to be SuccessfulConnection, got %s", querier.ConnectionStatus)
  152. }
  153. })
  154. // Test invalid configuration
  155. t.Run("invalid_configuration", func(t *testing.T) {
  156. invalidQuerier := &MockAthenaQuerier{
  157. AthenaQuerier: AthenaQuerier{
  158. AthenaConfiguration: AthenaConfiguration{
  159. // Missing required fields
  160. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  161. },
  162. },
  163. mockClient: mockClient,
  164. }
  165. err := invalidQuerier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
  166. if err == nil {
  167. t.Error("Query() should return error for invalid configuration")
  168. }
  169. if invalidQuerier.ConnectionStatus != cloud.InvalidConfiguration {
  170. t.Errorf("Expected connection status to be InvalidConfiguration, got %s", invalidQuerier.ConnectionStatus)
  171. }
  172. })
  173. // Test query execution failure
  174. t.Run("query_execution_failure", func(t *testing.T) {
  175. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  176. return nil, errors.New("query execution failed")
  177. }
  178. err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
  179. if err == nil {
  180. t.Error("Query() should return error when query execution fails")
  181. }
  182. if querier.ConnectionStatus != cloud.FailedConnection {
  183. t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
  184. }
  185. })
  186. // Test query waiting failure
  187. t.Run("query_waiting_failure", func(t *testing.T) {
  188. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  189. return &athena.StartQueryExecutionOutput{
  190. QueryExecutionId: aws.String("test-query-id"),
  191. }, nil
  192. }
  193. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  194. return &athena.GetQueryExecutionOutput{
  195. QueryExecution: &types.QueryExecution{
  196. Status: &types.QueryExecutionStatus{
  197. State: types.QueryExecutionStateFailed,
  198. },
  199. },
  200. }, nil
  201. }
  202. err := querier.Query(context.Background(), "SELECT * FROM test", GetAthenaQueryFunc(func(row types.Row) {}))
  203. if err == nil {
  204. t.Error("Query() should return error when query waiting fails")
  205. }
  206. if querier.ConnectionStatus != cloud.FailedConnection {
  207. t.Errorf("Expected connection status to be FailedConnection, got %s", querier.ConnectionStatus)
  208. }
  209. })
  210. }
  211. func TestAthenaQuerier_GetAthenaClient(t *testing.T) {
  212. // Test successful client creation
  213. t.Run("successful_client_creation", func(t *testing.T) {
  214. querier := &AthenaQuerier{
  215. AthenaConfiguration: AthenaConfiguration{
  216. Bucket: "test-bucket",
  217. Region: "us-east-1",
  218. Database: "test-db",
  219. Table: "test-table",
  220. Account: "123456789012",
  221. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  222. },
  223. }
  224. client, err := querier.GetAthenaClient()
  225. if err != nil {
  226. t.Errorf("GetAthenaClient() returned error: %v", err)
  227. }
  228. if client == nil {
  229. t.Error("GetAthenaClient() returned nil client")
  230. }
  231. })
  232. // Test client creation with service account authorizer
  233. t.Run("service_account_authorizer", func(t *testing.T) {
  234. querier := &AthenaQuerier{
  235. AthenaConfiguration: AthenaConfiguration{
  236. Bucket: "test-bucket",
  237. Region: "us-east-1",
  238. Database: "test-db",
  239. Table: "test-table",
  240. Account: "123456789012",
  241. Authorizer: &ServiceAccount{},
  242. },
  243. }
  244. client, err := querier.GetAthenaClient()
  245. if err != nil {
  246. t.Errorf("GetAthenaClient() with ServiceAccount returned error: %v", err)
  247. }
  248. if client == nil {
  249. t.Error("GetAthenaClient() returned nil client")
  250. }
  251. })
  252. // Test client creation with assume role authorizer
  253. t.Run("assume_role_authorizer", func(t *testing.T) {
  254. querier := &AthenaQuerier{
  255. AthenaConfiguration: AthenaConfiguration{
  256. Bucket: "test-bucket",
  257. Region: "us-east-1",
  258. Database: "test-db",
  259. Table: "test-table",
  260. Account: "123456789012",
  261. Authorizer: &AssumeRole{
  262. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  263. RoleARN: "arn:aws:iam::123456789012:role/test-role",
  264. },
  265. },
  266. }
  267. client, err := querier.GetAthenaClient()
  268. if err != nil {
  269. t.Errorf("GetAthenaClient() with AssumeRole returned error: %v", err)
  270. }
  271. if client == nil {
  272. t.Error("GetAthenaClient() returned nil client")
  273. }
  274. })
  275. // Test client creation failure with invalid authorizer
  276. t.Run("invalid_authorizer", func(t *testing.T) {
  277. querier := &AthenaQuerier{
  278. AthenaConfiguration: AthenaConfiguration{
  279. Bucket: "test-bucket",
  280. Region: "us-east-1",
  281. Database: "test-db",
  282. Table: "test-table",
  283. Account: "123456789012",
  284. Authorizer: &AccessKey{ID: "", Secret: ""}, // Invalid credentials
  285. },
  286. }
  287. client, err := querier.GetAthenaClient()
  288. if err == nil {
  289. t.Error("GetAthenaClient() should return error for invalid authorizer")
  290. }
  291. if client != nil {
  292. t.Error("GetAthenaClient() should return nil client on error")
  293. }
  294. })
  295. // Test client creation with different regions
  296. t.Run("different_regions", func(t *testing.T) {
  297. regions := []string{"us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"}
  298. for _, region := range regions {
  299. querier := &AthenaQuerier{
  300. AthenaConfiguration: AthenaConfiguration{
  301. Bucket: "test-bucket",
  302. Region: region,
  303. Database: "test-db",
  304. Table: "test-table",
  305. Account: "123456789012",
  306. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  307. },
  308. }
  309. client, err := querier.GetAthenaClient()
  310. if err != nil {
  311. t.Errorf("GetAthenaClient() for region %s returned error: %v", region, err)
  312. }
  313. if client == nil {
  314. t.Errorf("GetAthenaClient() for region %s returned nil client", region)
  315. }
  316. }
  317. })
  318. }
  319. func TestAthenaQuerier_queryAthenaPaginated(t *testing.T) {
  320. // Create mock client
  321. mockClient := &MockAthenaClient{}
  322. // Create mock querier with valid configuration
  323. querier := &MockAthenaQuerier{
  324. AthenaQuerier: AthenaQuerier{
  325. AthenaConfiguration: AthenaConfiguration{
  326. Bucket: "test-bucket",
  327. Region: "us-east-1",
  328. Database: "test-db",
  329. Table: "test-table",
  330. Account: "123456789012",
  331. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  332. },
  333. },
  334. mockClient: mockClient,
  335. }
  336. // Test successful paginated query
  337. t.Run("successful_paginated_query", func(t *testing.T) {
  338. queryString := "SELECT * FROM test_table"
  339. queryExecuted := false
  340. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  341. if *params.QueryString != queryString {
  342. t.Errorf("Expected query string %s, got %s", queryString, *params.QueryString)
  343. }
  344. if *params.QueryExecutionContext.Database != "test-db" {
  345. t.Errorf("Expected database test-db, got %s", *params.QueryExecutionContext.Database)
  346. }
  347. if *params.ResultConfiguration.OutputLocation != "test-bucket" {
  348. t.Errorf("Expected bucket test-bucket, got %s", *params.ResultConfiguration.OutputLocation)
  349. }
  350. queryExecuted = true
  351. return &athena.StartQueryExecutionOutput{
  352. QueryExecutionId: aws.String("test-query-id"),
  353. }, nil
  354. }
  355. mockClient.GetQueryResultsFunc = func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
  356. return &athena.GetQueryResultsOutput{
  357. ResultSet: &types.ResultSet{
  358. Rows: []types.Row{
  359. {Data: []types.Datum{{VarCharValue: aws.String("row1")}}},
  360. {Data: []types.Datum{{VarCharValue: aws.String("row2")}}},
  361. },
  362. },
  363. }, nil
  364. }
  365. rowsProcessed := 0
  366. queryFunc := func(page *athena.GetQueryResultsOutput) bool {
  367. for range page.ResultSet.Rows {
  368. rowsProcessed++
  369. }
  370. return true
  371. }
  372. err := querier.queryAthenaPaginated(context.Background(), queryString, queryFunc)
  373. if err != nil {
  374. t.Errorf("queryAthenaPaginated() returned error: %v", err)
  375. }
  376. if !queryExecuted {
  377. t.Error("Query execution was not called")
  378. }
  379. if rowsProcessed != 2 {
  380. t.Errorf("Expected 2 rows to be processed, got %d", rowsProcessed)
  381. }
  382. })
  383. // Test query with catalog
  384. t.Run("query_with_catalog", func(t *testing.T) {
  385. querierWithCatalog := &MockAthenaQuerier{
  386. AthenaQuerier: AthenaQuerier{
  387. AthenaConfiguration: AthenaConfiguration{
  388. Bucket: "test-bucket",
  389. Region: "us-east-1",
  390. Database: "test-db",
  391. Catalog: "test-catalog",
  392. Table: "test-table",
  393. Account: "123456789012",
  394. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  395. },
  396. },
  397. mockClient: mockClient,
  398. }
  399. catalogSet := false
  400. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  401. if params.QueryExecutionContext.Catalog != nil && *params.QueryExecutionContext.Catalog == "test-catalog" {
  402. catalogSet = true
  403. }
  404. return &athena.StartQueryExecutionOutput{
  405. QueryExecutionId: aws.String("test-query-id"),
  406. }, nil
  407. }
  408. err := querierWithCatalog.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
  409. if err != nil {
  410. t.Errorf("queryAthenaPaginated() with catalog returned error: %v", err)
  411. }
  412. if !catalogSet {
  413. t.Error("Catalog was not set in query execution context")
  414. }
  415. })
  416. // Test query with workgroup
  417. t.Run("query_with_workgroup", func(t *testing.T) {
  418. querierWithWorkgroup := &MockAthenaQuerier{
  419. AthenaQuerier: AthenaQuerier{
  420. AthenaConfiguration: AthenaConfiguration{
  421. Bucket: "test-bucket",
  422. Region: "us-east-1",
  423. Database: "test-db",
  424. Table: "test-table",
  425. Workgroup: "test-workgroup",
  426. Account: "123456789012",
  427. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  428. },
  429. },
  430. mockClient: mockClient,
  431. }
  432. workgroupSet := false
  433. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  434. if params.WorkGroup != nil && *params.WorkGroup == "test-workgroup" {
  435. workgroupSet = true
  436. }
  437. return &athena.StartQueryExecutionOutput{
  438. QueryExecutionId: aws.String("test-query-id"),
  439. }, nil
  440. }
  441. err := querierWithWorkgroup.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
  442. if err != nil {
  443. t.Errorf("queryAthenaPaginated() with workgroup returned error: %v", err)
  444. }
  445. if !workgroupSet {
  446. t.Error("Workgroup was not set in query execution input")
  447. }
  448. })
  449. // Test query execution failure
  450. t.Run("query_execution_failure", func(t *testing.T) {
  451. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  452. return nil, errors.New("query execution failed")
  453. }
  454. err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
  455. if err == nil {
  456. t.Error("queryAthenaPaginated() should return error when query execution fails")
  457. }
  458. expectedError := "QueryAthenaPaginated: start query error: query execution failed"
  459. if err.Error() != expectedError {
  460. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  461. }
  462. })
  463. // Test query waiting failure
  464. t.Run("query_waiting_failure", func(t *testing.T) {
  465. mockClient.StartQueryExecutionFunc = func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  466. return &athena.StartQueryExecutionOutput{
  467. QueryExecutionId: aws.String("test-query-id"),
  468. }, nil
  469. }
  470. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  471. return &athena.GetQueryExecutionOutput{
  472. QueryExecution: &types.QueryExecution{
  473. Status: &types.QueryExecutionStatus{
  474. State: types.QueryExecutionStateFailed,
  475. },
  476. },
  477. }, nil
  478. }
  479. err := querier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
  480. if err == nil {
  481. t.Error("queryAthenaPaginated() should return error when query waiting fails")
  482. }
  483. expectedError := "QueryAthenaPaginated: query execution error: no query results available for query test-query-id"
  484. if err.Error() != expectedError {
  485. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  486. }
  487. })
  488. // Test get client failure
  489. t.Run("get_client_failure", func(t *testing.T) {
  490. invalidQuerier := &FailingQueryAthenaQuerier{
  491. MockAthenaQuerier: MockAthenaQuerier{
  492. AthenaQuerier: AthenaQuerier{
  493. AthenaConfiguration: AthenaConfiguration{
  494. Bucket: "test-bucket",
  495. Region: "us-east-1",
  496. Database: "test-db",
  497. Table: "test-table",
  498. Account: "123456789012",
  499. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  500. },
  501. },
  502. mockClient: mockClient,
  503. },
  504. }
  505. err := invalidQuerier.queryAthenaPaginated(context.Background(), "SELECT * FROM test", func(page *athena.GetQueryResultsOutput) bool { return true })
  506. if err == nil {
  507. t.Error("queryAthenaPaginated() should return error when client creation fails")
  508. }
  509. expectedError := "QueryAthenaPaginated: GetAthenaClient error: failed to create client"
  510. if err.Error() != expectedError {
  511. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  512. }
  513. })
  514. }
  515. func TestAthenaQuerier_waitForQueryToComplete(t *testing.T) {
  516. // Create mock client
  517. mockClient := &MockAthenaClient{}
  518. // Create mock querier
  519. querier := &MockAthenaQuerier{
  520. AthenaQuerier: AthenaQuerier{
  521. AthenaConfiguration: AthenaConfiguration{
  522. Bucket: "test-bucket",
  523. Region: "us-east-1",
  524. Database: "test-db",
  525. Table: "test-table",
  526. Account: "123456789012",
  527. Authorizer: &AccessKey{ID: "test-key", Secret: "test-secret"},
  528. },
  529. },
  530. mockClient: mockClient,
  531. }
  532. // Test successful query completion
  533. t.Run("successful_query_completion", func(t *testing.T) {
  534. queryID := "test-query-id"
  535. callCount := 0
  536. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  537. callCount++
  538. if *params.QueryExecutionId != queryID {
  539. t.Errorf("Expected query ID %s, got %s", queryID, *params.QueryExecutionId)
  540. }
  541. // Return SUCCEEDED on first call
  542. return &athena.GetQueryExecutionOutput{
  543. QueryExecution: &types.QueryExecution{
  544. Status: &types.QueryExecutionStatus{
  545. State: types.QueryExecutionStateSucceeded,
  546. },
  547. },
  548. }, nil
  549. }
  550. err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
  551. if err != nil {
  552. t.Errorf("waitForQueryToComplete() returned error: %v", err)
  553. }
  554. if callCount != 1 {
  555. t.Errorf("Expected 1 call to GetQueryExecution, got %d", callCount)
  556. }
  557. })
  558. // Test query failure
  559. t.Run("query_failure", func(t *testing.T) {
  560. queryID := "test-query-id"
  561. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  562. return &athena.GetQueryExecutionOutput{
  563. QueryExecution: &types.QueryExecution{
  564. Status: &types.QueryExecutionStatus{
  565. State: types.QueryExecutionStateFailed,
  566. },
  567. },
  568. }, nil
  569. }
  570. err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
  571. if err == nil {
  572. t.Error("waitForQueryToComplete() should return error when query fails")
  573. }
  574. expectedError := "no query results available for query test-query-id"
  575. if err.Error() != expectedError {
  576. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  577. }
  578. })
  579. // Test query cancellation
  580. t.Run("query_cancellation", func(t *testing.T) {
  581. queryID := "test-query-id"
  582. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  583. return &athena.GetQueryExecutionOutput{
  584. QueryExecution: &types.QueryExecution{
  585. Status: &types.QueryExecutionStatus{
  586. State: types.QueryExecutionStateCancelled,
  587. },
  588. },
  589. }, nil
  590. }
  591. err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
  592. if err == nil {
  593. t.Error("waitForQueryToComplete() should return error when query is cancelled")
  594. }
  595. expectedError := "no query results available for query test-query-id"
  596. if err.Error() != expectedError {
  597. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  598. }
  599. })
  600. // Test query timeout
  601. t.Run("query_timeout", func(t *testing.T) {
  602. queryID := "test-query-id"
  603. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  604. return &athena.GetQueryExecutionOutput{
  605. QueryExecution: &types.QueryExecution{
  606. Status: &types.QueryExecutionStatus{
  607. State: "TIMED_OUT", // Use string literal since QueryExecutionStateTimedOut doesn't exist
  608. },
  609. },
  610. }, nil
  611. }
  612. err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
  613. if err == nil {
  614. t.Error("waitForQueryToComplete() should return error when query times out")
  615. }
  616. expectedError := "no query results available for query test-query-id"
  617. if err.Error() != expectedError {
  618. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  619. }
  620. })
  621. // Test GetQueryExecution error
  622. t.Run("get_query_execution_error", func(t *testing.T) {
  623. queryID := "test-query-id"
  624. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  625. return nil, errors.New("failed to get query execution")
  626. }
  627. err := querier.waitForQueryToComplete(context.Background(), mockClient, &queryID)
  628. if err == nil {
  629. t.Error("waitForQueryToComplete() should return error when GetQueryExecution fails")
  630. }
  631. expectedError := "failed to get query execution"
  632. if err.Error() != expectedError {
  633. t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
  634. }
  635. })
  636. // Test context cancellation
  637. t.Run("context_cancellation", func(t *testing.T) {
  638. queryID := "test-query-id"
  639. ctx, cancel := context.WithCancel(context.Background())
  640. cancel() // Cancel immediately
  641. mockClient.GetQueryExecutionFunc = func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  642. // Check if context is cancelled
  643. select {
  644. case <-ctx.Done():
  645. return nil, ctx.Err()
  646. default:
  647. return &athena.GetQueryExecutionOutput{
  648. QueryExecution: &types.QueryExecution{
  649. Status: &types.QueryExecutionStatus{
  650. State: types.QueryExecutionStateSucceeded,
  651. },
  652. },
  653. }, nil
  654. }
  655. }
  656. err := querier.waitForQueryToComplete(ctx, mockClient, &queryID)
  657. if err == nil {
  658. t.Error("waitForQueryToComplete() should return error when context is cancelled")
  659. }
  660. if err != context.Canceled {
  661. t.Errorf("Expected context.Canceled error, got %v", err)
  662. }
  663. })
  664. // Test with nil query execution ID
  665. t.Run("nil_query_execution_id", func(t *testing.T) {
  666. err := querier.waitForQueryToComplete(context.Background(), mockClient, nil)
  667. if err == nil {
  668. t.Error("waitForQueryToComplete() should return error when query execution ID is nil")
  669. }
  670. })
  671. }
  672. func TestAthenaQuerier_GetAthenaQueryFunc(t *testing.T) {
  673. // Test that GetAthenaQueryFunc returns a function
  674. queryFunc := GetAthenaQueryFunc(func(row types.Row) {
  675. // Do nothing
  676. })
  677. if queryFunc == nil {
  678. t.Error("GetAthenaQueryFunc should return a non-nil function")
  679. }
  680. // Test that the returned function can be called
  681. result := &athena.GetQueryResultsOutput{
  682. ResultSet: &types.ResultSet{
  683. Rows: []types.Row{
  684. {Data: []types.Datum{}},
  685. },
  686. },
  687. }
  688. // Should not panic
  689. queryFunc(result)
  690. }
  691. func TestGetAthenaRowValue(t *testing.T) {
  692. // Test with valid data
  693. row := types.Row{
  694. Data: []types.Datum{
  695. {VarCharValue: stringPtr("test-value")},
  696. },
  697. }
  698. queryColumnIndexes := map[string]int{
  699. "test-column": 0,
  700. }
  701. result := GetAthenaRowValue(row, queryColumnIndexes, "test-column")
  702. if result != "test-value" {
  703. t.Errorf("GetAthenaRowValue() = %v, want %v", result, "test-value")
  704. }
  705. // Test with missing column
  706. result = GetAthenaRowValue(row, queryColumnIndexes, "missing-column")
  707. if result != "" {
  708. t.Errorf("GetAthenaRowValue() with missing column = %v, want %v", result, "")
  709. }
  710. // Test with nil value
  711. rowWithNil := types.Row{
  712. Data: []types.Datum{
  713. {VarCharValue: nil},
  714. },
  715. }
  716. result = GetAthenaRowValue(rowWithNil, queryColumnIndexes, "test-column")
  717. if result != "" {
  718. t.Errorf("GetAthenaRowValue() with nil value = %v, want %v", result, "")
  719. }
  720. }
  721. func TestGetAthenaRowValueFloat(t *testing.T) {
  722. // Test with valid data
  723. row := types.Row{
  724. Data: []types.Datum{
  725. {VarCharValue: stringPtr("3.14159")},
  726. },
  727. }
  728. queryColumnIndexes := map[string]int{
  729. "test-column": 0,
  730. }
  731. result, err := GetAthenaRowValueFloat(row, queryColumnIndexes, "test-column")
  732. if err != nil {
  733. t.Errorf("GetAthenaRowValueFloat() returned error: %v", err)
  734. }
  735. if result != 3.14159 {
  736. t.Errorf("GetAthenaRowValueFloat() = %v, want %v", result, 3.14159)
  737. }
  738. // Test with missing column
  739. _, err = GetAthenaRowValueFloat(row, queryColumnIndexes, "missing-column")
  740. if err == nil {
  741. t.Error("GetAthenaRowValueFloat() should return error for missing column")
  742. }
  743. // Test with nil value
  744. rowWithNil := types.Row{
  745. Data: []types.Datum{
  746. {VarCharValue: nil},
  747. },
  748. }
  749. _, err = GetAthenaRowValueFloat(rowWithNil, queryColumnIndexes, "test-column")
  750. if err == nil {
  751. t.Error("GetAthenaRowValueFloat() should return error for nil value")
  752. }
  753. // Test with invalid float
  754. rowWithInvalid := types.Row{
  755. Data: []types.Datum{
  756. {VarCharValue: stringPtr("not-a-number")},
  757. },
  758. }
  759. _, err = GetAthenaRowValueFloat(rowWithInvalid, queryColumnIndexes, "test-column")
  760. if err == nil {
  761. t.Error("GetAthenaRowValueFloat() should return error for invalid float")
  762. }
  763. }
  764. func TestSelectAWSCategory(t *testing.T) {
  765. // Test network category (usage type ending in "Bytes")
  766. category := SelectAWSCategory("", "DataTransfer-Bytes", "")
  767. if category != "Network" {
  768. t.Errorf("SelectAWSCategory() for network = %v, want %v", category, "Network")
  769. }
  770. // Test compute category (provider ID with "i-" prefix)
  771. category = SelectAWSCategory("i-123456789", "", "")
  772. if category != "Compute" {
  773. t.Errorf("SelectAWSCategory() for compute = %v, want %v", category, "Compute")
  774. }
  775. // Test GuardDuty special case
  776. category = SelectAWSCategory("i-123456789", "", "AmazonGuardDuty")
  777. if category != "Other" {
  778. t.Errorf("SelectAWSCategory() for GuardDuty = %v, want %v", category, "Other")
  779. }
  780. // Test storage category (provider ID with "vol-" prefix)
  781. category = SelectAWSCategory("vol-123456789", "", "")
  782. if category != "Storage" {
  783. t.Errorf("SelectAWSCategory() for storage = %v, want %v", category, "Storage")
  784. }
  785. // Test service-based categories
  786. category = SelectAWSCategory("", "", "AmazonEKS")
  787. if category != "Management" {
  788. t.Errorf("SelectAWSCategory() for EKS = %v, want %v", category, "Management")
  789. }
  790. // Test fargate pod in EKS
  791. category = SelectAWSCategory("arn:aws:eks:us-west-2:123456789012:pod/cluster-name/pod-name", "", "AmazonEKS")
  792. if category != "Compute" {
  793. t.Errorf("SelectAWSCategory() for EKS fargate pod = %v, want %v", category, "Compute")
  794. }
  795. // Test other category as default
  796. category = SelectAWSCategory("", "", "SomeUnknownService")
  797. if category != "Other" {
  798. t.Errorf("SelectAWSCategory() for unknown service = %v, want %v", category, "Other")
  799. }
  800. }
  801. func TestParseARN(t *testing.T) {
  802. // Test valid ARN
  803. id := "arn:aws:elasticloadbalancing:us-east-1:297945954695:loadbalancer/a406f7761142e4ef58a8f2ba478d2db2"
  804. expected := "a406f7761142e4ef58a8f2ba478d2db2"
  805. result := ParseARN(id)
  806. if result != expected {
  807. t.Errorf("ParseARN() = %v, want %v", result, expected)
  808. }
  809. // Test invalid ARN (no match)
  810. id = "not-an-arn"
  811. result = ParseARN(id)
  812. if result != id {
  813. t.Errorf("ParseARN() for invalid ARN = %v, want %v", result, id)
  814. }
  815. // Test empty string
  816. id = ""
  817. result = ParseARN(id)
  818. if result != id {
  819. t.Errorf("ParseARN() for empty string = %v, want %v", result, id)
  820. }
  821. }