csvprovider.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package cloud
  2. import (
  3. "encoding/csv"
  4. "fmt"
  5. "io"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/aws/aws-sdk-go/aws"
  12. "github.com/aws/aws-sdk-go/aws/session"
  13. "github.com/aws/aws-sdk-go/service/s3"
  14. v1 "k8s.io/api/core/v1"
  15. "k8s.io/klog"
  16. "github.com/jszwec/csvutil"
  17. )
// refreshMinutes is the delay, in minutes, after which DownloadPricingData
// schedules its next reload of the CSV (see the time.AfterFunc call at the end
// of a successful load).
const refreshMinutes = 60
// CSVProvider is a pricing provider whose node and persistent-volume prices
// come from a CSV file, read either from a local path or from an "s3://" URI
// (CSVLocation). Prices are (re)loaded by DownloadPricingData.
type CSVProvider struct {
	*CustomProvider
	// CSVLocation is a local file path or an "s3://bucket/key" URI.
	CSVLocation string
	// Pricing maps a node lookup key ("InstanceID", or "Region,InstanceID"
	// when the row has a Region) to its price row.
	Pricing map[string]*price
	// NodeMapField is the InstanceIDField of the most recently loaded node row;
	// it names the node attribute (e.g. "spec.providerID", "metadata.name")
	// that GetKey reads to build lookup keys.
	NodeMapField string
	// PricingPV maps a PV lookup key to its price row.
	PricingPV map[string]*price
	// PVMapField is the InstanceIDField of the most recently loaded PV row; it
	// names the PV attribute that GetPVKey reads to build lookup keys.
	PVMapField string
	// UsesRegion is set to true when a loaded row has a non-empty Region;
	// GetKey then prefixes node keys with "<region>,".
	UsesRegion bool
	// DownloadPricingDataLock is held for writing by DownloadPricingData and for
	// reading by NodePricing and PVPricing.
	DownloadPricingDataLock sync.RWMutex
}
// price is one row of the pricing CSV. DownloadPricingData builds the decoder
// header with csvutil.Header(price{}, "csv") and passes it explicitly, so the
// column order of the file is expected to follow this field order.
// NOTE(review): do not reorder these fields without updating the CSV files.
type price struct {
	EndTimestamp string `csv:"EndTimestamp"`
	// InstanceID is the lookup key (prefixed with "Region," when Region is set).
	InstanceID string `csv:"InstanceID"`
	// Region, when non-empty, is prepended to InstanceID to form the key.
	Region string `csv:"Region"`
	// AssetClass is "node" or "pv"; anything else is treated as "node".
	AssetClass string `csv:"AssetClass"`
	// InstanceIDField names the Kubernetes object field (e.g. "spec.providerID",
	// "metadata.labels.<key>") whose value should match InstanceID.
	InstanceIDField string `csv:"InstanceIDField"`
	InstanceType    string `csv:"InstanceType"`
	// MarketPriceHourly is copied verbatim into Node.Cost / PV.Cost.
	MarketPriceHourly string `csv:"MarketPriceHourly"`
	Version           string `csv:"Version"`
}
  39. func GetCsv(location string) (io.Reader, error) {
  40. return os.Open(location)
  41. }
  42. func (c *CSVProvider) DownloadPricingData() error {
  43. c.DownloadPricingDataLock.Lock()
  44. defer c.DownloadPricingDataLock.Unlock()
  45. pricing := make(map[string]*price)
  46. pvpricing := make(map[string]*price)
  47. header, err := csvutil.Header(price{}, "csv")
  48. if err != nil {
  49. return err
  50. }
  51. fieldsPerRecord := len(header)
  52. var csvr io.Reader
  53. var csverr error
  54. if strings.HasPrefix(c.CSVLocation, "s3://") {
  55. region := os.Getenv("CSV_REGION")
  56. conf := aws.NewConfig().WithRegion(region).WithCredentialsChainVerboseErrors(true)
  57. s3Client := s3.New(session.New(conf))
  58. bucketAndKey := strings.Split(strings.TrimPrefix(c.CSVLocation, "s3://"), "/")
  59. if len(bucketAndKey) == 2 {
  60. out, err := s3Client.GetObject(&s3.GetObjectInput{
  61. Bucket: aws.String(bucketAndKey[0]),
  62. Key: aws.String(bucketAndKey[1]),
  63. })
  64. csverr = err
  65. csvr = out.Body
  66. } else {
  67. c.Pricing = pricing
  68. c.PricingPV = pvpricing
  69. return fmt.Errorf("Invalid s3 URI: %s", c.CSVLocation)
  70. }
  71. } else {
  72. csvr, csverr = GetCsv(c.CSVLocation)
  73. }
  74. if csverr != nil {
  75. klog.Infof("Error reading csv at %s: %s", c.CSVLocation, csverr)
  76. c.Pricing = pricing
  77. c.PricingPV = pvpricing
  78. return nil
  79. }
  80. csvReader := csv.NewReader(csvr)
  81. csvReader.Comma = ','
  82. csvReader.FieldsPerRecord = fieldsPerRecord
  83. dec, err := csvutil.NewDecoder(csvReader, header...)
  84. if err != nil {
  85. c.Pricing = pricing
  86. c.PricingPV = pvpricing
  87. return err
  88. }
  89. for {
  90. p := price{}
  91. err := dec.Decode(&p)
  92. csvParseErr, isCsvParseErr := err.(*csv.ParseError)
  93. if err == io.EOF {
  94. break
  95. } else if err == csvutil.ErrFieldCount || (isCsvParseErr && csvParseErr.Err == csv.ErrFieldCount) {
  96. rec := dec.Record()
  97. if len(rec) != 1 {
  98. klog.V(2).Infof("Expected %d price info fields but received %d: %s", fieldsPerRecord, len(rec), rec)
  99. continue
  100. }
  101. if strings.Index(rec[0], "#") == 0 {
  102. continue
  103. } else {
  104. klog.V(3).Infof("skipping non-CSV line: %s", rec)
  105. continue
  106. }
  107. } else if err != nil {
  108. klog.V(2).Infof("Error during spot info decode: %+v", err)
  109. continue
  110. }
  111. klog.V(4).Infof("Found price info %+v", p)
  112. key := p.InstanceID
  113. if p.Region != "" {
  114. key = fmt.Sprintf("%s,%s", p.Region, p.InstanceID)
  115. c.UsesRegion = true
  116. }
  117. if p.AssetClass == "pv" {
  118. pvpricing[key] = &p
  119. c.PVMapField = p.InstanceIDField
  120. } else if p.AssetClass == "node" {
  121. pricing[key] = &p
  122. c.NodeMapField = p.InstanceIDField
  123. } else {
  124. klog.Infof("Unrecognized asset class %s, defaulting to node", p.AssetClass)
  125. pricing[key] = &p
  126. c.NodeMapField = p.InstanceIDField
  127. }
  128. }
  129. if len(pricing) > 0 {
  130. c.Pricing = pricing
  131. c.PricingPV = pvpricing
  132. } else {
  133. klog.Infof("[WARNING] No data received from csv")
  134. }
  135. time.AfterFunc(refreshMinutes*time.Minute, func() { c.DownloadPricingData() })
  136. return nil
  137. }
// csvKey is the node lookup key returned by CSVProvider.GetKey.
type csvKey struct {
	// Labels holds the labels passed to GetKey; not read by csvKey's own methods.
	Labels map[string]string
	// ProviderID holds the lookup key returned by ID(); despite the name it is
	// whatever attribute NodeMapField selected, possibly "<region>,"-prefixed.
	ProviderID string
}
// Features always returns "" for CSV node keys; NodePricing matches node rows
// by ID() alone.
func (k *csvKey) Features() string {
	return ""
}
// GPUType always returns "": the CSV price rows carry no GPU information.
func (k *csvKey) GPUType() string {
	return ""
}
// ID returns the key used to look the node up in CSVProvider.Pricing, as built
// by GetKey from NodeMapField (and prefixed with "<region>," when UsesRegion).
func (k *csvKey) ID() string {
	return k.ProviderID
}
  151. func (c *CSVProvider) NodePricing(key Key) (*Node, error) {
  152. c.DownloadPricingDataLock.RLock()
  153. defer c.DownloadPricingDataLock.RUnlock()
  154. if p, ok := c.Pricing[key.ID()]; ok {
  155. return &Node{
  156. Cost: p.MarketPriceHourly,
  157. }, nil
  158. }
  159. s := strings.Split(key.ID(), ",") // Try without a region to be sure
  160. if len(s) == 2 {
  161. if p, ok := c.Pricing[s[1]]; ok {
  162. return &Node{
  163. Cost: p.MarketPriceHourly,
  164. }, nil
  165. }
  166. }
  167. return nil, fmt.Errorf("Unable to find Node matching %s", key.ID())
  168. }
  169. func NodeValueFromMapField(m string, n *v1.Node, useRegion bool) string {
  170. mf := strings.Split(m, ".")
  171. toReturn := ""
  172. if useRegion {
  173. toReturn = n.Labels[v1.LabelZoneRegion] + ","
  174. }
  175. if len(mf) == 2 && mf[0] == "spec" && mf[1] == "providerID" {
  176. provIdRx := regexp.MustCompile("aws:///([^/]+)/([^/]+)") // It's of the form aws:///us-east-2a/i-0fea4fd46592d050b and we want i-0fea4fd46592d050b, if it exists
  177. for matchNum, group := range provIdRx.FindStringSubmatch(n.Spec.ProviderID) {
  178. if matchNum == 2 {
  179. return toReturn + group
  180. }
  181. }
  182. if strings.HasPrefix(n.Spec.ProviderID, "azure://") {
  183. vmOrScaleSet := strings.TrimPrefix(n.Spec.ProviderID, "azure://")
  184. return toReturn + vmOrScaleSet
  185. }
  186. return toReturn + n.Spec.ProviderID
  187. } else if len(mf) > 1 && mf[0] == "metadata" {
  188. if mf[1] == "name" {
  189. return toReturn + n.Name
  190. } else if mf[1] == "labels" {
  191. lkey := strings.Join(mf[2:len(mf)], "")
  192. return toReturn + n.Labels[lkey]
  193. } else if mf[1] == "annotations" {
  194. akey := strings.Join(mf[2:len(mf)], "")
  195. return toReturn + n.Annotations[akey]
  196. } else {
  197. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  198. return ""
  199. }
  200. } else {
  201. klog.Infof("[ERROR] Unsupported InstanceIDField %s in CSV For Node", m)
  202. return ""
  203. }
  204. }
  205. func PVValueFromMapField(m string, n *v1.PersistentVolume) string {
  206. mf := strings.Split(m, ".")
  207. if len(mf) > 1 && mf[0] == "metadata" {
  208. if mf[1] == "name" {
  209. return n.Name
  210. } else if mf[1] == "labels" {
  211. lkey := strings.Join(mf[2:len(mf)], "")
  212. return n.Labels[lkey]
  213. } else if mf[1] == "annotations" {
  214. akey := strings.Join(mf[2:len(mf)], "")
  215. return n.Annotations[akey]
  216. } else {
  217. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  218. return ""
  219. }
  220. } else {
  221. klog.V(4).Infof("[ERROR] Unsupported InstanceIDField %s in CSV For PV", m)
  222. return ""
  223. }
  224. }
  225. func (c *CSVProvider) GetKey(l map[string]string, n *v1.Node) Key {
  226. id := NodeValueFromMapField(c.NodeMapField, n, c.UsesRegion)
  227. return &csvKey{
  228. ProviderID: id,
  229. Labels: l,
  230. }
  231. }
// csvPVKey is the persistent-volume lookup key returned by CSVProvider.GetPVKey.
type csvPVKey struct {
	// Labels are the PV's labels; not read by csvPVKey's own methods.
	Labels map[string]string
	// ProviderID holds the lookup key returned by Features(): the PV attribute
	// selected by the provider's PVMapField.
	ProviderID string
	// StorageClassName is returned by GetStorageClass.
	StorageClassName       string
	StorageClassParameters map[string]string
	Name                   string
	// DefaultRegion is the region passed to GetPVKey; not read by csvPVKey's
	// own methods.
	DefaultRegion string
}
// GetStorageClass returns the PV's storage class name, as captured by GetPVKey
// from pv.Spec.StorageClassName.
func (key *csvPVKey) GetStorageClass() string {
	return key.StorageClassName
}
// Features returns the key used by PVPricing to look the volume up in
// CSVProvider.PricingPV: the attribute selected by PVMapField in GetPVKey.
// Despite its name it is an identifier, not a feature list.
func (key *csvPVKey) Features() string {
	return key.ProviderID
}
  246. func (c *CSVProvider) GetPVKey(pv *v1.PersistentVolume, parameters map[string]string, defaultRegion string) PVKey {
  247. id := PVValueFromMapField(c.PVMapField, pv)
  248. return &csvPVKey{
  249. Labels: pv.Labels,
  250. ProviderID: id,
  251. StorageClassName: pv.Spec.StorageClassName,
  252. StorageClassParameters: parameters,
  253. Name: pv.Name,
  254. DefaultRegion: defaultRegion,
  255. }
  256. }
  257. func (c *CSVProvider) PVPricing(pvk PVKey) (*PV, error) {
  258. c.DownloadPricingDataLock.RLock()
  259. defer c.DownloadPricingDataLock.RUnlock()
  260. pricing, ok := c.PricingPV[pvk.Features()]
  261. if !ok {
  262. klog.V(4).Infof("Persistent Volume pricing not found for %s: %s", pvk.GetStorageClass(), pvk.Features())
  263. return &PV{}, nil
  264. }
  265. return &PV{
  266. Cost: pricing.MarketPriceHourly,
  267. }, nil
  268. }