athenaquerier_mock.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package aws
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "github.com/aws/aws-sdk-go-v2/aws"
  7. "github.com/aws/aws-sdk-go-v2/service/athena"
  8. "github.com/aws/aws-sdk-go-v2/service/athena/types"
  9. "github.com/opencost/opencost/pkg/cloud"
  10. )
  11. // MockAthenaClient is a mock implementation of the Athena client for testing
  12. type MockAthenaClient struct {
  13. StartQueryExecutionFunc func(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error)
  14. GetQueryExecutionFunc func(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error)
  15. GetQueryResultsFunc func(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error)
  16. }
  17. func (m *MockAthenaClient) StartQueryExecution(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error) {
  18. if m.StartQueryExecutionFunc != nil {
  19. return m.StartQueryExecutionFunc(ctx, params, optFns...)
  20. }
  21. return &athena.StartQueryExecutionOutput{
  22. QueryExecutionId: aws.String("mock-query-id-123"),
  23. }, nil
  24. }
  25. func (m *MockAthenaClient) GetQueryExecution(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error) {
  26. if m.GetQueryExecutionFunc != nil {
  27. return m.GetQueryExecutionFunc(ctx, params, optFns...)
  28. }
  29. return &athena.GetQueryExecutionOutput{
  30. QueryExecution: &types.QueryExecution{
  31. Status: &types.QueryExecutionStatus{
  32. State: types.QueryExecutionStateSucceeded,
  33. },
  34. },
  35. }, nil
  36. }
  37. func (m *MockAthenaClient) GetQueryResults(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
  38. if m.GetQueryResultsFunc != nil {
  39. return m.GetQueryResultsFunc(ctx, params, optFns...)
  40. }
  41. return &athena.GetQueryResultsOutput{
  42. ResultSet: &types.ResultSet{
  43. Rows: []types.Row{
  44. {Data: []types.Datum{}},
  45. },
  46. },
  47. }, nil
  48. }
  49. // MockAthenaQuerier wraps AthenaQuerier with a mock client for testing
  50. type MockAthenaQuerier struct {
  51. AthenaQuerier
  52. mockClient *MockAthenaClient
  53. }
  54. // FailingMockAthenaQuerier is a mock querier that fails on GetAthenaClient
  55. type FailingMockAthenaQuerier struct {
  56. MockAthenaQuerier
  57. }
  58. func (fmaq *FailingMockAthenaQuerier) GetAthenaClient() (*athena.Client, error) {
  59. return nil, errors.New("failed to create client")
  60. }
  61. // FailingQueryAthenaQuerier is a mock querier that fails in queryAthenaPaginated
  62. type FailingQueryAthenaQuerier struct {
  63. MockAthenaQuerier
  64. }
  65. func (fqaq *FailingQueryAthenaQuerier) GetAthenaClient() (*athena.Client, error) {
  66. return nil, errors.New("failed to create client")
  67. }
  68. func (fqaq *FailingQueryAthenaQuerier) queryAthenaPaginated(ctx context.Context, query string, fn func(*athena.GetQueryResultsOutput) bool) error {
  69. // Simulate GetAthenaClient failure
  70. _, err := fqaq.GetAthenaClient()
  71. if err != nil {
  72. return fmt.Errorf("QueryAthenaPaginated: GetAthenaClient error: %s", err.Error())
  73. }
  74. // Check if context is cancelled
  75. select {
  76. case <-ctx.Done():
  77. return ctx.Err()
  78. default:
  79. }
  80. // Acknowledge the query parameter to avoid unused parameter warning
  81. _ = query
  82. // Call the function with empty result to simulate no data
  83. fn(&athena.GetQueryResultsOutput{})
  84. return nil
  85. }
  86. func (maq *MockAthenaQuerier) GetAthenaClient() (*athena.Client, error) {
  87. // Return a real client but we'll override the methods in tests
  88. cfg, err := maq.Authorizer.CreateAWSConfig(maq.Region)
  89. if err != nil {
  90. return nil, err
  91. }
  92. cli := athena.NewFromConfig(cfg)
  93. return cli, nil
  94. }
  95. func (maq *MockAthenaQuerier) GetColumns() (map[string]bool, error) {
  96. columnSet := map[string]bool{}
  97. // This Query is supported by Athena tables and views
  98. q := `SELECT column_name FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'`
  99. query := fmt.Sprintf(q, maq.Database, maq.Table)
  100. athenaErr := maq.Query(context.TODO(), query, GetAthenaQueryFunc(func(row types.Row) {
  101. columnSet[*row.Data[0].VarCharValue] = true
  102. }))
  103. if athenaErr != nil {
  104. return columnSet, athenaErr
  105. }
  106. if len(columnSet) == 0 {
  107. // Don't log in tests
  108. }
  109. return columnSet, nil
  110. }
  111. func (maq *MockAthenaQuerier) Query(ctx context.Context, query string, fn func(*athena.GetQueryResultsOutput) bool) error {
  112. err := maq.Validate()
  113. if err != nil {
  114. maq.ConnectionStatus = cloud.InvalidConfiguration
  115. return err
  116. }
  117. // Use mock client instead of real one
  118. queryExecutionCtx := &types.QueryExecutionContext{
  119. Database: aws.String(maq.Database),
  120. }
  121. if maq.Catalog != "" {
  122. queryExecutionCtx.Catalog = aws.String(maq.Catalog)
  123. }
  124. resultConfiguration := &types.ResultConfiguration{
  125. OutputLocation: aws.String(maq.Bucket),
  126. }
  127. startQueryExecutionInput := &athena.StartQueryExecutionInput{
  128. QueryString: aws.String(query),
  129. QueryExecutionContext: queryExecutionCtx,
  130. ResultConfiguration: resultConfiguration,
  131. }
  132. if maq.Workgroup != "" {
  133. startQueryExecutionInput.WorkGroup = aws.String(maq.Workgroup)
  134. }
  135. // Use mock client
  136. startQueryExecutionOutput, err := maq.mockClient.StartQueryExecution(ctx, startQueryExecutionInput)
  137. if err != nil {
  138. maq.ConnectionStatus = cloud.FailedConnection
  139. return fmt.Errorf("QueryAthenaPaginated: start query error: %s", err.Error())
  140. }
  141. err = maq.waitForQueryToComplete(ctx, maq.mockClient, startQueryExecutionOutput.QueryExecutionId)
  142. if err != nil {
  143. maq.ConnectionStatus = cloud.FailedConnection
  144. return fmt.Errorf("QueryAthenaPaginated: query execution error: %s", err.Error())
  145. }
  146. queryResultsInput := &athena.GetQueryResultsInput{
  147. QueryExecutionId: startQueryExecutionOutput.QueryExecutionId,
  148. MaxResults: aws.Int32(1000),
  149. }
  150. // Simulate pagination
  151. pg, err := maq.mockClient.GetQueryResults(ctx, queryResultsInput)
  152. if err != nil {
  153. maq.ConnectionStatus = cloud.FailedConnection
  154. return err
  155. }
  156. fn(pg)
  157. maq.ConnectionStatus = cloud.SuccessfulConnection
  158. return nil
  159. }
  160. func (maq *MockAthenaQuerier) queryAthenaPaginated(ctx context.Context, query string, fn func(*athena.GetQueryResultsOutput) bool) error {
  161. queryExecutionCtx := &types.QueryExecutionContext{
  162. Database: aws.String(maq.Database),
  163. }
  164. if maq.Catalog != "" {
  165. queryExecutionCtx.Catalog = aws.String(maq.Catalog)
  166. }
  167. resultConfiguration := &types.ResultConfiguration{
  168. OutputLocation: aws.String(maq.Bucket),
  169. }
  170. startQueryExecutionInput := &athena.StartQueryExecutionInput{
  171. QueryString: aws.String(query),
  172. QueryExecutionContext: queryExecutionCtx,
  173. ResultConfiguration: resultConfiguration,
  174. }
  175. if maq.Workgroup != "" {
  176. startQueryExecutionInput.WorkGroup = aws.String(maq.Workgroup)
  177. }
  178. // Use mock client
  179. startQueryExecutionOutput, err := maq.mockClient.StartQueryExecution(ctx, startQueryExecutionInput)
  180. if err != nil {
  181. return fmt.Errorf("QueryAthenaPaginated: start query error: %s", err.Error())
  182. }
  183. err = maq.waitForQueryToComplete(ctx, maq.mockClient, startQueryExecutionOutput.QueryExecutionId)
  184. if err != nil {
  185. return fmt.Errorf("QueryAthenaPaginated: query execution error: %s", err.Error())
  186. }
  187. queryResultsInput := &athena.GetQueryResultsInput{
  188. QueryExecutionId: startQueryExecutionOutput.QueryExecutionId,
  189. MaxResults: aws.Int32(1000),
  190. }
  191. // Simulate pagination
  192. pg, err := maq.mockClient.GetQueryResults(ctx, queryResultsInput)
  193. if err != nil {
  194. return err
  195. }
  196. fn(pg)
  197. return nil
  198. }
  199. func (maq *MockAthenaQuerier) waitForQueryToComplete(ctx context.Context, client *MockAthenaClient, queryExecutionID *string) error {
  200. if queryExecutionID == nil {
  201. return fmt.Errorf("query execution ID is nil")
  202. }
  203. inp := &athena.GetQueryExecutionInput{
  204. QueryExecutionId: queryExecutionID,
  205. }
  206. // Simulate waiting with mock
  207. qe, err := client.GetQueryExecution(ctx, inp)
  208. if err != nil {
  209. return err
  210. }
  211. if qe.QueryExecution.Status.State != "SUCCEEDED" {
  212. return fmt.Errorf("no query results available for query %s", *queryExecutionID)
  213. }
  214. return nil
  215. }
  // stringPtr returns a pointer to a freshly allocated copy of s.
  func stringPtr(s string) *string {
  	p := new(string)
  	*p = s
  	return p
  }