2
0

csvprovider.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package cloud
  2. import (
  3. "encoding/csv"
  4. "fmt"
  5. "io"
  6. "os"
  7. "strings"
  8. "sync"
  9. "time"
  10. "github.com/aws/aws-sdk-go/aws"
  11. "github.com/aws/aws-sdk-go/aws/session"
  12. "github.com/aws/aws-sdk-go/service/s3"
  13. v1 "k8s.io/api/core/v1"
  14. "k8s.io/klog"
  15. "github.com/jszwec/csvutil"
  16. )
  // refreshMinutes is the delay, in minutes, between a successful CSV load in
// DownloadPricingData and the next scheduled reload. It is left untyped so it
// can be multiplied by time.Minute directly.
const (
	refreshMinutes = 60
)
  18. type CSVProvider struct {
  19. *CustomProvider
  20. CSVLocation string
  21. Pricing map[string]*price
  22. NodeMapField string
  23. PricingPV map[string]*price
  24. PVMapField string
  25. DownloadPricingDataLock sync.RWMutex
  26. }
  // price is one row of the pricing CSV. DownloadPricingData derives the CSV
// header from these csv tags (csvutil.Header), so columns are expected in this
// field order; do not reorder the fields.
type price struct {
	// EndTimestamp is decoded but not read anywhere in this file.
	EndTimestamp string `csv:"EndTimestamp"`
	// InstanceID is the key the row is stored under in Pricing / PricingPV.
	InstanceID string `csv:"InstanceID"`
	// AssetClass is "pv" for volume rows; anything else is treated as a node.
	AssetClass string `csv:"AssetClass"`
	// InstanceIDField is the dotted object path whose value should match
	// InstanceID, e.g. "spec.providerID" or "metadata.labels.<key>".
	InstanceIDField string `csv:"InstanceIDField"`
	// InstanceType is decoded but not read anywhere in this file.
	InstanceType string `csv:"InstanceType"`
	// MarketPriceHourly is copied verbatim into Node.Cost / PV.Cost.
	MarketPriceHourly string `csv:"MarketPriceHourly"`
	// Version is decoded but not read anywhere in this file.
	Version string `csv:"Version"`
}
  // GetCsv opens the local file at location for reading.
//
// On success the returned reader is the *os.File itself, so the caller owns it
// and should close it (e.g. via a type assertion to io.Closer) when finished.
// On failure the reader is a true nil interface rather than a nil *os.File
// wrapped in a non-nil io.Reader, so `r == nil` checks behave as expected.
func GetCsv(location string) (io.Reader, error) {
	f, err := os.Open(location)
	if err != nil {
		return nil, err
	}
	return f, nil
}
  39. func (c *CSVProvider) DownloadPricingData() error {
  40. c.DownloadPricingDataLock.Lock()
  41. defer c.DownloadPricingDataLock.Unlock()
  42. pricing := make(map[string]*price)
  43. pvpricing := make(map[string]*price)
  44. header, err := csvutil.Header(price{}, "csv")
  45. if err != nil {
  46. return err
  47. }
  48. fieldsPerRecord := len(header)
  49. var csvr io.Reader
  50. var csverr error
  51. if strings.HasPrefix(c.CSVLocation, "s3://") {
  52. region := os.Getenv("CSV_REGION")
  53. conf := aws.NewConfig().WithRegion(region).WithCredentialsChainVerboseErrors(true)
  54. s3Client := s3.New(session.New(conf))
  55. bucketAndKey := strings.Split(strings.TrimPrefix(c.CSVLocation, "s3://"), "/")
  56. if len(bucketAndKey) == 2 {
  57. out, err := s3Client.GetObject(&s3.GetObjectInput{
  58. Bucket: aws.String(bucketAndKey[0]),
  59. Key: aws.String(bucketAndKey[1]),
  60. })
  61. csverr = err
  62. csvr = out.Body
  63. } else {
  64. c.Pricing = pricing
  65. c.PricingPV = pvpricing
  66. return fmt.Errorf("Invalid s3 URI: %s", c.CSVLocation)
  67. }
  68. } else {
  69. csvr, csverr = GetCsv(c.CSVLocation)
  70. }
  71. if csverr != nil {
  72. klog.Infof("Error reading csv at %s: %s", c.CSVLocation, csverr)
  73. c.Pricing = pricing
  74. c.PricingPV = pvpricing
  75. return nil
  76. }
  77. csvReader := csv.NewReader(csvr)
  78. csvReader.Comma = ','
  79. csvReader.FieldsPerRecord = fieldsPerRecord
  80. dec, err := csvutil.NewDecoder(csvReader, header...)
  81. if err != nil {
  82. c.Pricing = pricing
  83. c.PricingPV = pvpricing
  84. return err
  85. }
  86. for {
  87. p := price{}
  88. err := dec.Decode(&p)
  89. csvParseErr, isCsvParseErr := err.(*csv.ParseError)
  90. if err == io.EOF {
  91. break
  92. } else if err == csvutil.ErrFieldCount || (isCsvParseErr && csvParseErr.Err == csv.ErrFieldCount) {
  93. rec := dec.Record()
  94. if len(rec) != 1 {
  95. klog.V(2).Infof("Expected %d price info fields but received %d: %s", fieldsPerRecord, len(rec), rec)
  96. continue
  97. }
  98. if strings.Index(rec[0], "#") == 0 {
  99. continue
  100. } else {
  101. klog.V(3).Infof("skipping non-CSV line: %s", rec)
  102. continue
  103. }
  104. } else if err != nil {
  105. klog.V(2).Infof("Error during spot info decode: %+v", err)
  106. continue
  107. }
  108. klog.V(4).Infof("Found price info %+v", p)
  109. if p.AssetClass == "pv" {
  110. pvpricing[p.InstanceID] = &p
  111. c.PVMapField = p.InstanceIDField
  112. } else if p.AssetClass == "node" {
  113. pricing[p.InstanceID] = &p
  114. c.NodeMapField = p.InstanceIDField
  115. } else {
  116. klog.Infof("Unrecognized asset class %s, defaulting to node", p.AssetClass)
  117. pricing[p.InstanceID] = &p
  118. c.NodeMapField = p.InstanceIDField
  119. }
  120. }
  121. if len(pricing) > 0 {
  122. c.Pricing = pricing
  123. c.PricingPV = pvpricing
  124. } else {
  125. klog.Infof("[WARNING] No data received from csv")
  126. }
  127. time.AfterFunc(refreshMinutes*time.Minute, func() { c.DownloadPricingData() })
  128. return nil
  129. }
  // csvKey is the Key produced by CSVProvider.GetKey. NodePricing looks prices
// up by ID(), which is the node value selected by the CSV's InstanceIDField.
// Labels is carried along but not consulted by any method here.
type csvKey struct {
	Labels     map[string]string
	ProviderID string
}

// Features always reports no features: CSV node lookups rely on ID alone.
func (k *csvKey) Features() string { return "" }

// GPUType always reports no GPU type: the CSV row has no GPU column.
func (k *csvKey) GPUType() string { return "" }

// ID returns the value extracted from the node by GetKey.
func (k *csvKey) ID() string { return k.ProviderID }
  143. func (c *CSVProvider) NodePricing(key Key) (*Node, error) {
  144. c.DownloadPricingDataLock.RLock()
  145. defer c.DownloadPricingDataLock.RUnlock()
  146. if p, ok := c.Pricing[key.ID()]; ok {
  147. return &Node{
  148. Cost: p.MarketPriceHourly,
  149. }, nil
  150. }
  151. return nil, fmt.Errorf("Unable to find Node matching %s", key.ID())
  152. }
  153. func NodeValueFromMapField(m string, n *v1.Node) string {
  154. mf := strings.Split(m, ".")
  155. if len(mf) == 2 && mf[0] == "spec" && mf[1] == "providerID" {
  156. return n.Spec.ProviderID
  157. } else if len(mf) > 1 && mf[0] == "metadata" {
  158. if mf[1] == "name" {
  159. return n.Name
  160. } else if mf[1] == "labels" {
  161. lkey := strings.Join(mf[2:len(mf)], "")
  162. return n.Labels[lkey]
  163. } else if mf[1] == "annotations" {
  164. akey := strings.Join(mf[2:len(mf)], "")
  165. return n.Annotations[akey]
  166. } else {
  167. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  168. return ""
  169. }
  170. } else {
  171. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  172. return ""
  173. }
  174. }
  175. func PVValueFromMapField(m string, n *v1.PersistentVolume) string {
  176. mf := strings.Split(m, ".")
  177. if len(mf) > 1 && mf[0] == "metadata" {
  178. if mf[1] == "name" {
  179. return n.Name
  180. } else if mf[1] == "labels" {
  181. lkey := strings.Join(mf[2:len(mf)], "")
  182. return n.Labels[lkey]
  183. } else if mf[1] == "annotations" {
  184. akey := strings.Join(mf[2:len(mf)], "")
  185. return n.Annotations[akey]
  186. } else {
  187. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  188. return ""
  189. }
  190. } else {
  191. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  192. return ""
  193. }
  194. }
  195. func (c *CSVProvider) GetKey(l map[string]string, n *v1.Node) Key {
  196. id := NodeValueFromMapField(c.NodeMapField, n)
  197. return &csvKey{
  198. ProviderID: id,
  199. Labels: l,
  200. }
  201. }
  // csvPVKey is the PVKey produced by CSVProvider.GetPVKey. PVPricing looks
// prices up by Features(), which here is the ProviderID extracted from the
// volume. The remaining fields are recorded from the volume and the caller's
// arguments but are not consulted by the methods below.
type csvPVKey struct {
	Labels                 map[string]string
	ProviderID             string
	StorageClassName       string
	StorageClassParameters map[string]string
	Name                   string
	DefaultRegion          string
}

// GetStorageClass reports the volume's storage class name.
func (key *csvPVKey) GetStorageClass() string { return key.StorageClassName }

// Features reports the extracted ID that serves as the CSV pricing key.
func (key *csvPVKey) Features() string { return key.ProviderID }
  216. func (c *CSVProvider) GetPVKey(pv *v1.PersistentVolume, parameters map[string]string, defaultRegion string) PVKey {
  217. id := PVValueFromMapField(c.PVMapField, pv)
  218. return &csvPVKey{
  219. Labels: pv.Labels,
  220. ProviderID: id,
  221. StorageClassName: pv.Spec.StorageClassName,
  222. StorageClassParameters: parameters,
  223. Name: pv.Name,
  224. DefaultRegion: defaultRegion,
  225. }
  226. }
  227. func (c *CSVProvider) PVPricing(pvk PVKey) (*PV, error) {
  228. c.DownloadPricingDataLock.RLock()
  229. defer c.DownloadPricingDataLock.RUnlock()
  230. pricing, ok := c.PricingPV[pvk.Features()]
  231. if !ok {
  232. klog.V(4).Infof("Persistent Volume pricing not found for %s: %s", pvk.GetStorageClass(), pvk.Features())
  233. return &PV{}, nil
  234. }
  235. return &PV{
  236. Cost: pricing.MarketPriceHourly,
  237. }, nil
  238. }