s3selectquerier.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package aws
  2. import (
  3. "context"
  4. "encoding/csv"
  5. "fmt"
  6. "io"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types"
  13. "github.com/opencost/opencost/pkg/cloud"
  14. "github.com/opencost/opencost/pkg/cloud/config"
  15. "github.com/opencost/opencost/pkg/util/stringutil"
  16. )
  17. type S3SelectQuerier struct {
  18. S3Connection
  19. connectionStatus cloud.ConnectionStatus
  20. }
  21. func (s3sq *S3SelectQuerier) Equals(config config.Config) bool {
  22. thatConfig, ok := config.(*S3SelectQuerier)
  23. if !ok {
  24. return false
  25. }
  26. return s3sq.S3Connection.Equals(&thatConfig.S3Connection)
  27. }
  28. func (s3sq *S3SelectQuerier) Query(query string, queryKeys []string, cli *s3.Client, fn func(*csv.Reader) error) error {
  29. for _, queryKey := range queryKeys {
  30. reader, err2 := s3sq.fetchCSVReader(query, queryKey, cli, s3Types.FileHeaderInfoUse)
  31. if err2 != nil {
  32. return err2
  33. }
  34. err2 = fn(reader)
  35. if err2 != nil {
  36. return err2
  37. }
  38. }
  39. return nil
  40. }
  41. // GetQueryKeys returns a list of s3 object names, where the there are 1 object for each month within the range between
  42. // start and end
  43. func (s3sq *S3SelectQuerier) GetQueryKeys(start, end time.Time, client *s3.Client) ([]string, error) {
  44. objs, err := s3sq.ListObjects(client)
  45. if err != nil {
  46. return nil, err
  47. }
  48. monthStrings, err := getMonthStrings(start, end)
  49. if err != err {
  50. return nil, err
  51. }
  52. var queryKeys []string
  53. // Find all matching "csv.gz" files per monthString
  54. for _, monthStr := range monthStrings {
  55. for _, obj := range objs.Contents {
  56. if strings.Contains(*obj.Key, monthStr) && strings.HasSuffix(*obj.Key, ".csv.gz") {
  57. queryKeys = append(queryKeys, *obj.Key)
  58. }
  59. }
  60. }
  61. if len(queryKeys) == 0 {
  62. return nil, fmt.Errorf("no CUR files for given time range")
  63. }
  64. return queryKeys, nil
  65. }
  66. func (s3sq *S3SelectQuerier) fetchCSVReader(query string, queryKey string, client *s3.Client, fileHeaderInfo s3Types.FileHeaderInfo) (*csv.Reader, error) {
  67. input := &s3.SelectObjectContentInput{
  68. Bucket: aws.String(s3sq.Bucket),
  69. Key: aws.String(queryKey),
  70. Expression: aws.String(query),
  71. ExpressionType: s3Types.ExpressionTypeSql,
  72. InputSerialization: &s3Types.InputSerialization{
  73. CompressionType: s3Types.CompressionTypeGzip,
  74. CSV: &s3Types.CSVInput{
  75. FileHeaderInfo: fileHeaderInfo,
  76. },
  77. },
  78. OutputSerialization: &s3Types.OutputSerialization{
  79. CSV: &s3Types.CSVOutput{},
  80. },
  81. }
  82. res, err := client.SelectObjectContent(context.TODO(), input)
  83. if err != nil {
  84. return nil, err
  85. }
  86. resStream := res.GetStream()
  87. // todo: this needs work
  88. results, resultWriter := io.Pipe()
  89. go func() {
  90. defer resultWriter.Close()
  91. defer resStream.Close()
  92. resStream.Events()
  93. for event := range resStream.Events() {
  94. switch e := event.(type) {
  95. case *s3Types.SelectObjectContentEventStreamMemberRecords:
  96. resultWriter.Write(e.Value.Payload)
  97. case *s3Types.SelectObjectContentEventStreamMemberEnd:
  98. break
  99. }
  100. }
  101. }()
  102. if err := resStream.Err(); err != nil {
  103. return nil, fmt.Errorf("failed to read from SelectObjectContent EventStream, %v", err)
  104. }
  105. return csv.NewReader(results), nil
  106. }
  107. func getMonthStrings(start, end time.Time) ([]string, error) {
  108. if start.After(end) {
  109. return []string{}, fmt.Errorf("start date must be before end date")
  110. }
  111. if end.After(time.Now()) {
  112. end = time.Now()
  113. }
  114. dateTemplate := "%d%02d01-%d%02d01/"
  115. // set to first of the month
  116. currMonth := start.AddDate(0, 0, -start.Day()+1)
  117. nextMonth := currMonth.AddDate(0, 1, 0)
  118. monthStr := fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  119. // Create string for end condition
  120. endMonth := end.AddDate(0, 0, -end.Day()+1)
  121. endNextMonth := endMonth.AddDate(0, 1, 0)
  122. endStr := fmt.Sprintf(dateTemplate, endMonth.Year(), int(endMonth.Month()), endNextMonth.Year(), int(endNextMonth.Month()))
  123. var monthStrs []string
  124. monthStrs = append(monthStrs, monthStr)
  125. for monthStr != endStr {
  126. currMonth = nextMonth
  127. nextMonth = nextMonth.AddDate(0, 1, 0)
  128. monthStr = fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month()))
  129. monthStrs = append(monthStrs, monthStr)
  130. }
  131. return monthStrs, nil
  132. }
  133. // GetCSVRowValue retrieve value from athena row based on column names and used stringutil.Bank() to prevent duplicate
  134. // allocation of strings
  135. func GetCSVRowValue(row []string, queryColumnIndexes map[string]int, columnName string) string {
  136. if row == nil {
  137. return ""
  138. }
  139. columnIndex, ok := queryColumnIndexes[columnName]
  140. if !ok {
  141. return ""
  142. }
  143. return stringutil.Bank(row[columnIndex])
  144. }
  145. // GetCSVRowValueFloat retrieve value from athena row based on column names and convert to float if possible.
  146. func GetCSVRowValueFloat(row []string, queryColumnIndexes map[string]int, columnName string) (float64, error) {
  147. if row == nil {
  148. return 0.0, fmt.Errorf("getCSVRowValueFloat: nil row")
  149. }
  150. columnIndex, ok := queryColumnIndexes[columnName]
  151. if !ok {
  152. return 0.0, fmt.Errorf("getCSVRowValueFloat: missing column index: %s", columnName)
  153. }
  154. cost, err := strconv.ParseFloat(row[columnIndex], 64)
  155. if err != nil {
  156. return cost, fmt.Errorf("getCSVRowValueFloat: failed to parse %s: '%s': %s", columnName, row[columnIndex], err.Error())
  157. }
  158. return cost, nil
  159. }