csvprovider.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package cloud
  2. import (
  3. "encoding/csv"
  4. "fmt"
  5. "io"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/aws/aws-sdk-go/aws"
  12. "github.com/aws/aws-sdk-go/aws/session"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. v1 "k8s.io/api/core/v1"
  15. "k8s.io/klog"
  16. "github.com/jszwec/csvutil"
  17. )
  18. const refreshMinutes = 60
  19. type CSVProvider struct {
  20. *CustomProvider
  21. CSVLocation string
  22. Pricing map[string]*price
  23. NodeMapField string
  24. PricingPV map[string]*price
  25. PVMapField string
  26. DownloadPricingDataLock sync.RWMutex
  27. }
  28. type price struct {
  29. EndTimestamp string `csv:"EndTimestamp"`
  30. InstanceID string `csv:"InstanceID"`
  31. AssetClass string `csv:"AssetClass"`
  32. InstanceIDField string `csv:"InstanceIDField"`
  33. InstanceType string `csv:"InstanceType"`
  34. MarketPriceHourly string `csv:"MarketPriceHourly"`
  35. Version string `csv:"Version"`
  36. }
  37. func GetCsv(location string) (io.Reader, error) {
  38. return os.Open(location)
  39. }
  40. func (c *CSVProvider) DownloadPricingData() error {
  41. c.DownloadPricingDataLock.Lock()
  42. defer c.DownloadPricingDataLock.Unlock()
  43. pricing := make(map[string]*price)
  44. pvpricing := make(map[string]*price)
  45. header, err := csvutil.Header(price{}, "csv")
  46. if err != nil {
  47. return err
  48. }
  49. fieldsPerRecord := len(header)
  50. var csvr io.Reader
  51. var csverr error
  52. if strings.HasPrefix(c.CSVLocation, "s3://") {
  53. region := os.Getenv("CSV_REGION")
  54. conf := aws.NewConfig().WithRegion(region).WithCredentialsChainVerboseErrors(true)
  55. s3Client := s3.New(session.New(conf))
  56. bucketAndKey := strings.Split(strings.TrimPrefix(c.CSVLocation, "s3://"), "/")
  57. if len(bucketAndKey) == 2 {
  58. out, err := s3Client.GetObject(&s3.GetObjectInput{
  59. Bucket: aws.String(bucketAndKey[0]),
  60. Key: aws.String(bucketAndKey[1]),
  61. })
  62. csverr = err
  63. csvr = out.Body
  64. } else {
  65. c.Pricing = pricing
  66. c.PricingPV = pvpricing
  67. return fmt.Errorf("Invalid s3 URI: %s", c.CSVLocation)
  68. }
  69. } else {
  70. csvr, csverr = GetCsv(c.CSVLocation)
  71. }
  72. if csverr != nil {
  73. klog.Infof("Error reading csv at %s: %s", c.CSVLocation, csverr)
  74. c.Pricing = pricing
  75. c.PricingPV = pvpricing
  76. return nil
  77. }
  78. csvReader := csv.NewReader(csvr)
  79. csvReader.Comma = ','
  80. csvReader.FieldsPerRecord = fieldsPerRecord
  81. dec, err := csvutil.NewDecoder(csvReader, header...)
  82. if err != nil {
  83. c.Pricing = pricing
  84. c.PricingPV = pvpricing
  85. return err
  86. }
  87. for {
  88. p := price{}
  89. err := dec.Decode(&p)
  90. csvParseErr, isCsvParseErr := err.(*csv.ParseError)
  91. if err == io.EOF {
  92. break
  93. } else if err == csvutil.ErrFieldCount || (isCsvParseErr && csvParseErr.Err == csv.ErrFieldCount) {
  94. rec := dec.Record()
  95. if len(rec) != 1 {
  96. klog.V(2).Infof("Expected %d price info fields but received %d: %s", fieldsPerRecord, len(rec), rec)
  97. continue
  98. }
  99. if strings.Index(rec[0], "#") == 0 {
  100. continue
  101. } else {
  102. klog.V(3).Infof("skipping non-CSV line: %s", rec)
  103. continue
  104. }
  105. } else if err != nil {
  106. klog.V(2).Infof("Error during spot info decode: %+v", err)
  107. continue
  108. }
  109. klog.V(4).Infof("Found price info %+v", p)
  110. if p.AssetClass == "pv" {
  111. pvpricing[p.InstanceID] = &p
  112. c.PVMapField = p.InstanceIDField
  113. } else if p.AssetClass == "node" {
  114. pricing[p.InstanceID] = &p
  115. c.NodeMapField = p.InstanceIDField
  116. } else {
  117. klog.Infof("Unrecognized asset class %s, defaulting to node", p.AssetClass)
  118. pricing[p.InstanceID] = &p
  119. c.NodeMapField = p.InstanceIDField
  120. }
  121. }
  122. if len(pricing) > 0 {
  123. c.Pricing = pricing
  124. c.PricingPV = pvpricing
  125. } else {
  126. klog.Infof("[WARNING] No data received from csv")
  127. }
  128. time.AfterFunc(refreshMinutes*time.Minute, func() { c.DownloadPricingData() })
  129. return nil
  130. }
  131. type csvKey struct {
  132. Labels map[string]string
  133. ProviderID string
  134. }
  135. func (k *csvKey) Features() string {
  136. return ""
  137. }
  138. func (k *csvKey) GPUType() string {
  139. return ""
  140. }
  141. func (k *csvKey) ID() string {
  142. return k.ProviderID
  143. }
  144. func (c *CSVProvider) NodePricing(key Key) (*Node, error) {
  145. c.DownloadPricingDataLock.RLock()
  146. defer c.DownloadPricingDataLock.RUnlock()
  147. if p, ok := c.Pricing[key.ID()]; ok {
  148. return &Node{
  149. Cost: p.MarketPriceHourly,
  150. }, nil
  151. }
  152. return nil, fmt.Errorf("Unable to find Node matching %s", key.ID())
  153. }
  154. func NodeValueFromMapField(m string, n *v1.Node) string {
  155. mf := strings.Split(m, ".")
  156. if len(mf) == 2 && mf[0] == "spec" && mf[1] == "providerID" {
  157. provIdRx := regexp.MustCompile("aws:///([^/]+)/([^/]+)") // It's of the form aws:///us-east-2a/i-0fea4fd46592d050b and we want i-0fea4fd46592d050b, if it exists
  158. for matchNum, group := range provIdRx.FindStringSubmatch(n.Spec.ProviderID) {
  159. if matchNum == 2 {
  160. return group
  161. }
  162. }
  163. if strings.HasPrefix(n.Spec.ProviderID, "azure://") {
  164. vmOrScaleSet := strings.TrimPrefix(n.Spec.ProviderID, "azure://")
  165. return vmOrScaleSet
  166. }
  167. return n.Spec.ProviderID
  168. } else if len(mf) > 1 && mf[0] == "metadata" {
  169. if mf[1] == "name" {
  170. return n.Name
  171. } else if mf[1] == "labels" {
  172. lkey := strings.Join(mf[2:len(mf)], "")
  173. return n.Labels[lkey]
  174. } else if mf[1] == "annotations" {
  175. akey := strings.Join(mf[2:len(mf)], "")
  176. return n.Annotations[akey]
  177. } else {
  178. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  179. return ""
  180. }
  181. } else {
  182. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  183. return ""
  184. }
  185. }
  186. func PVValueFromMapField(m string, n *v1.PersistentVolume) string {
  187. mf := strings.Split(m, ".")
  188. if len(mf) > 1 && mf[0] == "metadata" {
  189. if mf[1] == "name" {
  190. return n.Name
  191. } else if mf[1] == "labels" {
  192. lkey := strings.Join(mf[2:len(mf)], "")
  193. return n.Labels[lkey]
  194. } else if mf[1] == "annotations" {
  195. akey := strings.Join(mf[2:len(mf)], "")
  196. return n.Annotations[akey]
  197. } else {
  198. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  199. return ""
  200. }
  201. } else {
  202. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  203. return ""
  204. }
  205. }
  206. func (c *CSVProvider) GetKey(l map[string]string, n *v1.Node) Key {
  207. id := NodeValueFromMapField(c.NodeMapField, n)
  208. return &csvKey{
  209. ProviderID: id,
  210. Labels: l,
  211. }
  212. }
  213. type csvPVKey struct {
  214. Labels map[string]string
  215. ProviderID string
  216. StorageClassName string
  217. StorageClassParameters map[string]string
  218. Name string
  219. DefaultRegion string
  220. }
  221. func (key *csvPVKey) GetStorageClass() string {
  222. return key.StorageClassName
  223. }
  224. func (key *csvPVKey) Features() string {
  225. return key.ProviderID
  226. }
  227. func (c *CSVProvider) GetPVKey(pv *v1.PersistentVolume, parameters map[string]string, defaultRegion string) PVKey {
  228. id := PVValueFromMapField(c.PVMapField, pv)
  229. return &csvPVKey{
  230. Labels: pv.Labels,
  231. ProviderID: id,
  232. StorageClassName: pv.Spec.StorageClassName,
  233. StorageClassParameters: parameters,
  234. Name: pv.Name,
  235. DefaultRegion: defaultRegion,
  236. }
  237. }
  238. func (c *CSVProvider) PVPricing(pvk PVKey) (*PV, error) {
  239. c.DownloadPricingDataLock.RLock()
  240. defer c.DownloadPricingDataLock.RUnlock()
  241. pricing, ok := c.PricingPV[pvk.Features()]
  242. if !ok {
  243. klog.V(4).Infof("Persistent Volume pricing not found for %s: %s", pvk.GetStorageClass(), pvk.Features())
  244. return &PV{}, nil
  245. }
  246. return &PV{
  247. Cost: pricing.MarketPriceHourly,
  248. }, nil
  249. }