s3selectquerier.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package aws
  2. import (
  3. "context"
  4. "encoding/csv"
  5. "fmt"
  6. "io"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types"
  13. "github.com/opencost/opencost/core/pkg/util/stringutil"
  14. "github.com/opencost/opencost/pkg/cloud"
  15. )
// S3SelectQuerier embeds an S3Connection and uses it to run S3 Select queries against gzip-compressed CSV objects
// stored in the connection's bucket.
type S3SelectQuerier struct {
	// S3Connection is embedded, so its fields and methods (e.g. Bucket, ListObjects, Equals) are promoted onto
	// the querier.
	S3Connection
	// connectionStatus holds a cloud.ConnectionStatus value; none of the methods visible in this file read or
	// write it.
	connectionStatus cloud.ConnectionStatus
}
  20. func (s3sq *S3SelectQuerier) Equals(config cloud.Config) bool {
  21. thatConfig, ok := config.(*S3SelectQuerier)
  22. if !ok {
  23. return false
  24. }
  25. return s3sq.S3Connection.Equals(&thatConfig.S3Connection)
  26. }
  27. func (s3sq *S3SelectQuerier) Query(query string, queryKeys []string, cli *s3.Client, fn func(*csv.Reader) error) error {
  28. for _, queryKey := range queryKeys {
  29. reader, err2 := s3sq.fetchCSVReader(query, queryKey, cli, s3Types.FileHeaderInfoUse)
  30. if err2 != nil {
  31. return err2
  32. }
  33. err2 = fn(reader)
  34. if err2 != nil {
  35. return err2
  36. }
  37. }
  38. return nil
  39. }
  40. // GetQueryKeys returns a list of s3 object names, where the there are 1 object for each month within the range between
  41. // start and end
  42. func (s3sq *S3SelectQuerier) GetQueryKeys(start, end time.Time, client *s3.Client) ([]string, error) {
  43. objs, err := s3sq.ListObjects(client)
  44. if err != nil {
  45. return nil, err
  46. }
  47. monthStrings, err := getMonthStrings(start, end)
  48. if err != err {
  49. return nil, err
  50. }
  51. var queryKeys []string
  52. // Find all matching "csv.gz" files per monthString
  53. for _, monthStr := range monthStrings {
  54. for _, obj := range objs.Contents {
  55. if strings.Contains(*obj.Key, monthStr) && strings.HasSuffix(*obj.Key, ".csv.gz") {
  56. queryKeys = append(queryKeys, *obj.Key)
  57. }
  58. }
  59. }
  60. if len(queryKeys) == 0 {
  61. return nil, fmt.Errorf("no CUR files for given time range")
  62. }
  63. return queryKeys, nil
  64. }
  65. func (s3sq *S3SelectQuerier) fetchCSVReader(query string, queryKey string, client *s3.Client, fileHeaderInfo s3Types.FileHeaderInfo) (*csv.Reader, error) {
  66. input := &s3.SelectObjectContentInput{
  67. Bucket: aws.String(s3sq.Bucket),
  68. Key: aws.String(queryKey),
  69. Expression: aws.String(query),
  70. ExpressionType: s3Types.ExpressionTypeSql,
  71. InputSerialization: &s3Types.InputSerialization{
  72. CompressionType: s3Types.CompressionTypeGzip,
  73. CSV: &s3Types.CSVInput{
  74. FileHeaderInfo: fileHeaderInfo,
  75. },
  76. },
  77. OutputSerialization: &s3Types.OutputSerialization{
  78. CSV: &s3Types.CSVOutput{},
  79. },
  80. }
  81. res, err := client.SelectObjectContent(context.TODO(), input)
  82. if err != nil {
  83. return nil, err
  84. }
  85. resStream := res.GetStream()
  86. // todo: this needs work
  87. results, resultWriter := io.Pipe()
  88. go func() {
  89. defer resultWriter.Close()
  90. defer resStream.Close()
  91. resStream.Events()
  92. for event := range resStream.Events() {
  93. switch e := event.(type) {
  94. case *s3Types.SelectObjectContentEventStreamMemberRecords:
  95. resultWriter.Write(e.Value.Payload)
  96. case *s3Types.SelectObjectContentEventStreamMemberEnd:
  97. break
  98. }
  99. }
  100. }()
  101. if err := resStream.Err(); err != nil {
  102. return nil, fmt.Errorf("failed to read from SelectObjectContent EventStream, %v", err)
  103. }
  104. return csv.NewReader(results), nil
  105. }
  106. func getMonthStrings(start, end time.Time) ([]string, error) {
  107. if start.After(end) {
  108. return []string{}, fmt.Errorf("start date must be before end date")
  109. }
  110. if end.After(time.Now()) {
  111. end = time.Now()
  112. }
  113. dateTemplate := "%d%02d01-%d%02d01/"
  114. // set to first of the month
  115. currMonth := start.AddDate(0, 0, -start.Day()+1)
  116. nextMonth := currMonth.AddDate(0, 1, 0)
  117. monthStr := fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  118. // Create string for end condition
  119. endMonth := end.AddDate(0, 0, -end.Day()+1)
  120. endNextMonth := endMonth.AddDate(0, 1, 0)
  121. endStr := fmt.Sprintf(dateTemplate, endMonth.Year(), int(endMonth.Month()), endNextMonth.Year(), int(endNextMonth.Month()))
  122. var monthStrs []string
  123. monthStrs = append(monthStrs, monthStr)
  124. for monthStr != endStr {
  125. currMonth = nextMonth
  126. nextMonth = nextMonth.AddDate(0, 1, 0)
  127. monthStr = fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  128. monthStrs = append(monthStrs, monthStr)
  129. }
  130. return monthStrs, nil
  131. }
  132. // GetCSVRowValue retrieve value from athena row based on column names and used stringutil.Bank() to prevent duplicate
  133. // allocation of strings
  134. func GetCSVRowValue(row []string, queryColumnIndexes map[string]int, columnName string) string {
  135. if row == nil {
  136. return ""
  137. }
  138. columnIndex, ok := queryColumnIndexes[columnName]
  139. if !ok {
  140. return ""
  141. }
  142. return stringutil.Bank(row[columnIndex])
  143. }
  144. // GetCSVRowValueFloat retrieve value from athena row based on column names and convert to float if possible.
  145. func GetCSVRowValueFloat(row []string, queryColumnIndexes map[string]int, columnName string) (float64, error) {
  146. if row == nil {
  147. return 0.0, fmt.Errorf("getCSVRowValueFloat: nil row")
  148. }
  149. columnIndex, ok := queryColumnIndexes[columnName]
  150. if !ok {
  151. return 0.0, fmt.Errorf("getCSVRowValueFloat: missing column index: %s", columnName)
  152. }
  153. cost, err := strconv.ParseFloat(row[columnIndex], 64)
  154. if err != nil {
  155. return cost, fmt.Errorf("getCSVRowValueFloat: failed to parse %s: '%s': %s", columnName, row[columnIndex], err.Error())
  156. }
  157. return cost, nil
  158. }