s3selectquerier.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package aws
  2. import (
  3. "context"
  4. "encoding/csv"
  5. "fmt"
  6. "io"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types"
  13. "github.com/opencost/opencost/pkg/cloud/config"
  14. "github.com/opencost/opencost/pkg/util/stringutil"
  15. )
  16. type S3SelectQuerier struct {
  17. S3Connection
  18. }
  19. func (s3sq *S3SelectQuerier) Equals(config config.Config) bool {
  20. thatConfig, ok := config.(*S3SelectQuerier)
  21. if !ok {
  22. return false
  23. }
  24. return s3sq.S3Connection.Equals(&thatConfig.S3Connection)
  25. }
  26. func (s3sq *S3SelectQuerier) Query(query string, queryKeys []string, cli *s3.Client, fn func(*csv.Reader) error) error {
  27. for _, queryKey := range queryKeys {
  28. reader, err2 := s3sq.fetchCSVReader(query, queryKey, cli, s3Types.FileHeaderInfoUse)
  29. if err2 != nil {
  30. return err2
  31. }
  32. err2 = fn(reader)
  33. if err2 != nil {
  34. return err2
  35. }
  36. }
  37. return nil
  38. }
  39. // GetQueryKeys returns a list of s3 object names, where the there are 1 object for each month within the range between
  40. // start and end
  41. func (s3sq *S3SelectQuerier) GetQueryKeys(start, end time.Time, client *s3.Client) ([]string, error) {
  42. objs, err := s3sq.ListObjects(client)
  43. if err != nil {
  44. return nil, err
  45. }
  46. monthStrings, err := getMonthStrings(start, end)
  47. if err != err {
  48. return nil, err
  49. }
  50. var queryKeys []string
  51. // Find all matching "csv.gz" files per monthString
  52. for _, monthStr := range monthStrings {
  53. for _, obj := range objs.Contents {
  54. if strings.Contains(*obj.Key, monthStr) && strings.HasSuffix(*obj.Key, ".csv.gz") {
  55. queryKeys = append(queryKeys, *obj.Key)
  56. }
  57. }
  58. }
  59. if len(queryKeys) == 0 {
  60. return nil, fmt.Errorf("no CUR files for given time range")
  61. }
  62. return queryKeys, nil
  63. }
  64. func (s3sq *S3SelectQuerier) fetchCSVReader(query string, queryKey string, client *s3.Client, fileHeaderInfo s3Types.FileHeaderInfo) (*csv.Reader, error) {
  65. input := &s3.SelectObjectContentInput{
  66. Bucket: aws.String(s3sq.Bucket),
  67. Key: aws.String(queryKey),
  68. Expression: aws.String(query),
  69. ExpressionType: s3Types.ExpressionTypeSql,
  70. InputSerialization: &s3Types.InputSerialization{
  71. CompressionType: s3Types.CompressionTypeGzip,
  72. CSV: &s3Types.CSVInput{
  73. FileHeaderInfo: fileHeaderInfo,
  74. },
  75. },
  76. OutputSerialization: &s3Types.OutputSerialization{
  77. CSV: &s3Types.CSVOutput{},
  78. },
  79. }
  80. res, err := client.SelectObjectContent(context.TODO(), input)
  81. if err != nil {
  82. return nil, err
  83. }
  84. resStream := res.GetStream()
  85. // todo: this needs work
  86. results, resultWriter := io.Pipe()
  87. go func() {
  88. defer resultWriter.Close()
  89. defer resStream.Close()
  90. resStream.Events()
  91. for event := range resStream.Events() {
  92. switch e := event.(type) {
  93. case *s3Types.SelectObjectContentEventStreamMemberRecords:
  94. resultWriter.Write(e.Value.Payload)
  95. case *s3Types.SelectObjectContentEventStreamMemberEnd:
  96. break
  97. }
  98. }
  99. }()
  100. if err := resStream.Err(); err != nil {
  101. return nil, fmt.Errorf("failed to read from SelectObjectContent EventStream, %v", err)
  102. }
  103. return csv.NewReader(results), nil
  104. }
  105. func getMonthStrings(start, end time.Time) ([]string, error) {
  106. if start.After(end) {
  107. return []string{}, fmt.Errorf("start date must be before end date")
  108. }
  109. if end.After(time.Now()) {
  110. end = time.Now()
  111. }
  112. dateTemplate := "%d%02d01-%d%02d01/"
  113. // set to first of the month
  114. currMonth := start.AddDate(0, 0, -start.Day()+1)
  115. nextMonth := currMonth.AddDate(0, 1, 0)
  116. monthStr := fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  117. // Create string for end condition
  118. endMonth := end.AddDate(0, 0, -end.Day()+1)
  119. endNextMonth := endMonth.AddDate(0, 1, 0)
  120. endStr := fmt.Sprintf(dateTemplate, endMonth.Year(), int(endMonth.Month()), endNextMonth.Year(), int(endNextMonth.Month()))
  121. var monthStrs []string
  122. monthStrs = append(monthStrs, monthStr)
  123. for monthStr != endStr {
  124. currMonth = nextMonth
  125. nextMonth = nextMonth.AddDate(0, 1, 0)
  126. monthStr = fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  127. monthStrs = append(monthStrs, monthStr)
  128. }
  129. return monthStrs, nil
  130. }
  131. // GetCSVRowValue retrieve value from athena row based on column names and used stringutil.Bank() to prevent duplicate
  132. // allocation of strings
  133. func GetCSVRowValue(row []string, queryColumnIndexes map[string]int, columnName string) string {
  134. if row == nil {
  135. return ""
  136. }
  137. columnIndex, ok := queryColumnIndexes[columnName]
  138. if !ok {
  139. return ""
  140. }
  141. return stringutil.Bank(row[columnIndex])
  142. }
  143. // GetCSVRowValueFloat retrieve value from athena row based on column names and convert to float if possible.
  144. func GetCSVRowValueFloat(row []string, queryColumnIndexes map[string]int, columnName string) (float64, error) {
  145. if row == nil {
  146. return 0.0, fmt.Errorf("getCSVRowValueFloat: nil row")
  147. }
  148. columnIndex, ok := queryColumnIndexes[columnName]
  149. if !ok {
  150. return 0.0, fmt.Errorf("getCSVRowValueFloat: missing column index: %s", columnName)
  151. }
  152. cost, err := strconv.ParseFloat(row[columnIndex], 64)
  153. if err != nil {
  154. return cost, fmt.Errorf("getCSVRowValueFloat: failed to parse %s: '%s': %s", columnName, row[columnIndex], err.Error())
  155. }
  156. return cost, nil
  157. }