s3selectquerier.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package aws
  2. import (
  3. "context"
  4. "encoding/csv"
  5. "fmt"
  6. "io"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/aws/aws-sdk-go-v2/aws"
  11. "github.com/aws/aws-sdk-go-v2/service/s3"
  12. s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types"
  13. "github.com/opencost/opencost/core/pkg/util/stringutil"
  14. "github.com/opencost/opencost/pkg/cloud"
  15. )
  16. type S3SelectQuerier struct {
  17. S3Connection
  18. connectionStatus cloud.ConnectionStatus
  19. }
  20. func (s3sq *S3SelectQuerier) Equals(config cloud.Config) bool {
  21. thatConfig, ok := config.(*S3SelectQuerier)
  22. if !ok {
  23. return false
  24. }
  25. return s3sq.S3Connection.Equals(&thatConfig.S3Connection)
  26. }
  // Query runs the S3 Select expression query against every object key in
  // queryKeys, in order, and passes each resulting CSV reader to fn. Objects are
  // read with FileHeaderInfoUse, so S3 treats each object's first line as a
  // header row. Processing stops at the first error returned either by the
  // S3 request or by fn, and that error is returned unchanged.
  func (s3sq *S3SelectQuerier) Query(query string, queryKeys []string, cli *s3.Client, fn func(*csv.Reader) error) error {
  	for i := range queryKeys {
  		reader, err := s3sq.fetchCSVReader(query, queryKeys[i], cli, s3Types.FileHeaderInfoUse)
  		if err != nil {
  			return err
  		}
  		if err := fn(reader); err != nil {
  			return err
  		}
  	}
  	return nil
  }
  40. func (s3sq *S3SelectQuerier) GetHeaders(queryKey string, cli *s3.Client) ([]string, error) {
  41. reader, err := s3sq.fetchCSVReader("SELECT * FROM S3Object LIMIT 1", queryKey, cli, s3Types.FileHeaderInfoNone)
  42. if err != nil {
  43. return nil, err
  44. }
  45. record, err := reader.Read()
  46. if err != nil {
  47. return nil, err
  48. }
  49. return record, nil
  50. }
  // GetQueryKeys returns the keys of all ".csv.gz" objects in the bucket whose key
  // contains one of the monthly "YYYYMM01-YYYYMM01/" segments produced by
  // getMonthStrings for [start, end]. Keys are grouped by month in ascending
  // order; within a month they keep the order of the listing.
  //
  // It returns an error if listing the bucket fails, if the date range is
  // rejected by getMonthStrings, or if no matching object is found.
  func (s3sq *S3SelectQuerier) GetQueryKeys(start, end time.Time, client *s3.Client) ([]string, error) {
  	objs, err := s3sq.ListObjects(client)
  	if err != nil {
  		return nil, err
  	}
  	monthStrings, err := getMonthStrings(start, end)
  	if err != nil {
  		return nil, err
  	}
  	var queryKeys []string
  	// Find all matching "csv.gz" files per monthString.
  	for _, monthStr := range monthStrings {
  		for _, obj := range objs.Contents {
  			// An entry without a key cannot be queried; skip it rather than
  			// dereference a nil pointer.
  			if obj.Key == nil {
  				continue
  			}
  			key := *obj.Key
  			if strings.Contains(key, monthStr) && strings.HasSuffix(key, ".csv.gz") {
  				queryKeys = append(queryKeys, key)
  			}
  		}
  	}
  	if len(queryKeys) == 0 {
  		return nil, fmt.Errorf("no CUR files for given time range")
  	}
  	return queryKeys, nil
  }
  76. func (s3sq *S3SelectQuerier) fetchCSVReader(query string, queryKey string, client *s3.Client, fileHeaderInfo s3Types.FileHeaderInfo) (*csv.Reader, error) {
  77. input := &s3.SelectObjectContentInput{
  78. Bucket: aws.String(s3sq.Bucket),
  79. Key: aws.String(queryKey),
  80. Expression: aws.String(query),
  81. ExpressionType: s3Types.ExpressionTypeSql,
  82. InputSerialization: &s3Types.InputSerialization{
  83. CompressionType: s3Types.CompressionTypeGzip,
  84. CSV: &s3Types.CSVInput{
  85. FileHeaderInfo: fileHeaderInfo,
  86. },
  87. },
  88. OutputSerialization: &s3Types.OutputSerialization{
  89. CSV: &s3Types.CSVOutput{},
  90. },
  91. }
  92. res, err := client.SelectObjectContent(context.TODO(), input)
  93. if err != nil {
  94. return nil, err
  95. }
  96. resStream := res.GetStream()
  97. // todo: this needs work
  98. results, resultWriter := io.Pipe()
  99. go func() {
  100. defer resultWriter.Close()
  101. defer resStream.Close()
  102. resStream.Events()
  103. for event := range resStream.Events() {
  104. switch e := event.(type) {
  105. case *s3Types.SelectObjectContentEventStreamMemberRecords:
  106. resultWriter.Write(e.Value.Payload)
  107. case *s3Types.SelectObjectContentEventStreamMemberEnd:
  108. break
  109. }
  110. }
  111. }()
  112. if err := resStream.Err(); err != nil {
  113. return nil, fmt.Errorf("failed to read from SelectObjectContent EventStream, %v", err)
  114. }
  115. return csv.NewReader(results), nil
  116. }
  // getMonthStrings returns one "YYYYMM01-YYYYMM01/" segment (first of the month
  // through first of the following month) for every calendar month from start's
  // month through end's month inclusive, in ascending order. end is first clamped
  // to time.Now().
  //
  // It returns an error when start is after end. It also returns an error when
  // start falls in a month later than the (clamped) end month; the month loop
  // could never reach the end month in that case.
  func getMonthStrings(start, end time.Time) ([]string, error) {
  	if start.After(end) {
  		return []string{}, fmt.Errorf("start date must be before end date")
  	}
  	if now := time.Now(); end.After(now) {
  		end = now
  	}
  	const dateTemplate = "%d%02d01-%d%02d01/"

  	// monthIndex maps a time to a counter that increases by one per calendar
  	// month. This lets months be compared regardless of day or time of day.
  	monthIndex := func(t time.Time) int { return t.Year()*12 + int(t.Month()) }

  	// Move start to the first of its month. Adding whole months to a day-1 date
  	// never spills over into the month after.
  	currMonth := start.AddDate(0, 0, -start.Day()+1)
  	endIndex := monthIndex(end)
  	if monthIndex(currMonth) > endIndex {
  		return []string{}, fmt.Errorf("start date must not be after the current month")
  	}

  	var monthStrs []string
  	for ; monthIndex(currMonth) <= endIndex; currMonth = currMonth.AddDate(0, 1, 0) {
  		nextMonth := currMonth.AddDate(0, 1, 0)
  		monthStrs = append(monthStrs, fmt.Sprintf(dateTemplate, currMonth.Year(), int(currMonth.Month()), nextMonth.Year(), int(nextMonth.Month())))
  	}
  	return monthStrs, nil
  }
  143. // GetCSVRowValue retrieve value from athena row based on column names and used stringutil.Bank() to prevent duplicate
  144. // allocation of strings
  145. func GetCSVRowValue(row []string, queryColumnIndexes map[string]int, columnName string) string {
  146. if row == nil {
  147. return ""
  148. }
  149. columnIndex, ok := queryColumnIndexes[columnName]
  150. if !ok {
  151. return ""
  152. }
  153. return stringutil.Bank(row[columnIndex])
  154. }
  // GetCSVRowValueFloat returns the field named columnName from a CSV row, parsed
  // as a float64. queryColumnIndexes maps column names to field positions.
  //
  // It returns an error in four cases:
  //   - row is nil;
  //   - columnName has no entry in queryColumnIndexes;
  //   - the index is outside the row;
  //   - the field fails strconv.ParseFloat.
  //
  // In the parse case, the returned float is whatever ParseFloat produced, and
  // the error wraps the strconv error.
  func GetCSVRowValueFloat(row []string, queryColumnIndexes map[string]int, columnName string) (float64, error) {
  	if row == nil {
  		return 0.0, fmt.Errorf("getCSVRowValueFloat: nil row")
  	}
  	columnIndex, ok := queryColumnIndexes[columnName]
  	if !ok {
  		return 0.0, fmt.Errorf("getCSVRowValueFloat: missing column index: %s", columnName)
  	}
  	// Guard against short or malformed records instead of panicking.
  	if columnIndex < 0 || columnIndex >= len(row) {
  		return 0.0, fmt.Errorf("getCSVRowValueFloat: column index %d for %s out of range for row of length %d", columnIndex, columnName, len(row))
  	}
  	cost, err := strconv.ParseFloat(row[columnIndex], 64)
  	if err != nil {
  		return cost, fmt.Errorf("getCSVRowValueFloat: failed to parse %s: '%s': %w", columnName, row[columnIndex], err)
  	}
  	return cost, nil
  }