module.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. package public
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "sync"
  7. "time"
  8. "github.com/opencost/opencost/core/pkg/log"
  9. "github.com/opencost/opencost/core/pkg/pricing"
  10. "github.com/opencost/opencost/core/pkg/reader"
  11. "github.com/opencost/opencost/core/pkg/unit"
  12. )
  13. type PricingModuleConfig struct {
  14. Provider pricing.Provider
  15. Currency unit.Currency
  16. RefreshInterval time.Duration
  17. }
  18. type PricingModule struct {
  19. config PricingModuleConfig
  20. Providers *ProviderPricing `json:"provider" yaml:"provider"`
  21. pricingSet *pricing.PricingSet
  22. mu sync.RWMutex
  23. stopCh chan struct{}
  24. doneCh chan struct{}
  25. }
  26. func NewPricingModule(config PricingModuleConfig) (*PricingModule, error) {
  27. pm := &PricingModule{
  28. config: config,
  29. Providers: &ProviderPricing{},
  30. stopCh: make(chan struct{}),
  31. doneCh: make(chan struct{}),
  32. }
  33. ctx := context.Background()
  34. // Generate pricing data directly from the provider API
  35. pricingSet, err := GeneratePricingForProvider(config.Provider, config.Currency)
  36. if err != nil {
  37. return nil, fmt.Errorf("failed to generate pricing: %w", err)
  38. }
  39. // Store the pricing set for reader access
  40. pm.pricingSet = pricingSet
  41. err = pm.indexPricingSet(ctx, pricingSet)
  42. if err != nil {
  43. return nil, fmt.Errorf("failed to index pricing: %w", err)
  44. }
  45. // Start background refresh if configured
  46. if config.RefreshInterval > 0 {
  47. go pm.backgroundRefresh()
  48. log.Infof("Started background pricing refresh with interval: %v", config.RefreshInterval)
  49. }
  50. return pm, nil
  51. }
  52. type ProviderPricing map[pricing.Provider]*InstanceTypePricing
  53. type InstanceTypePricing map[string]*RegionPricing
  54. type RegionPricing map[string]*pricing.Prices
  55. func (pm *PricingModule) indexPricingSet(_ context.Context, pricingSet *pricing.PricingSet) error {
  56. providers := make(ProviderPricing)
  57. // Index nodes
  58. for _, node := range pricingSet.Nodes {
  59. provider := node.Properties.Provider
  60. instanceType := node.Properties.InstanceType
  61. region := node.Properties.Region
  62. // Instance type map
  63. if providers[provider] == nil {
  64. instanceMap := make(InstanceTypePricing)
  65. providers[provider] = &instanceMap
  66. }
  67. // Region map
  68. if (*providers[provider])[instanceType] == nil {
  69. regionMap := make(RegionPricing)
  70. (*providers[provider])[instanceType] = &regionMap
  71. }
  72. (*(*providers[provider])[instanceType])[region] = &node.Prices
  73. }
  74. // Index volumes
  75. for _, volume := range pricingSet.Volumes {
  76. provider := volume.Properties.Provider
  77. volumeType := string(volume.Properties.VolumeType)
  78. region := volume.Properties.Region
  79. // Instance type map
  80. if providers[provider] == nil {
  81. instanceMap := make(InstanceTypePricing)
  82. providers[provider] = &instanceMap
  83. }
  84. // Region map
  85. if (*providers[provider])[volumeType] == nil {
  86. regionMap := make(RegionPricing)
  87. (*providers[provider])[volumeType] = &regionMap
  88. }
  89. (*(*providers[provider])[volumeType])[region] = &volume.Prices
  90. }
  91. pm.Providers = &providers
  92. log.Infof("Indexed %d node pricing records and %d volume pricing records for provider %s (%s)",
  93. len(pricingSet.Nodes), len(pricingSet.Volumes), pm.config.Provider, pm.config.Currency)
  94. return nil
  95. }
  96. // GetNodePricing provides fast lookup for node pricing by provider, instance type, and region
  97. func (pm *PricingModule) GetNodePricing(provider pricing.Provider, instanceType string, region string) (*pricing.NodePricing, error) {
  98. pm.mu.RLock()
  99. defer pm.mu.RUnlock()
  100. if pm.Providers == nil {
  101. return nil, fmt.Errorf("pricing not loaded")
  102. }
  103. providerPricing := (*pm.Providers)[provider]
  104. if providerPricing == nil {
  105. return nil, fmt.Errorf("provider %s not found", provider)
  106. }
  107. instancePricing := (*providerPricing)[instanceType]
  108. if instancePricing == nil {
  109. return nil, fmt.Errorf("instance type %s not found for provider %s", instanceType, provider)
  110. }
  111. regionPricing := (*instancePricing)[region]
  112. if regionPricing == nil {
  113. return nil, fmt.Errorf("region %s not found for instance type %s in provider %s", region, instanceType, provider)
  114. }
  115. // Reconstruct NodePricing from Prices
  116. return &pricing.NodePricing{
  117. Properties: pricing.NodePricingProperties{
  118. Provider: provider,
  119. InstanceType: instanceType,
  120. Region: region,
  121. },
  122. Prices: *regionPricing,
  123. }, nil
  124. }
  125. // GetVolumePricing provides fast lookup for node pricing by provider, instance type, and region
  126. func (pm *PricingModule) GetVolumePricing(provider pricing.Provider, volumeType string, region string) (*pricing.VolumePricing, error) {
  127. pm.mu.RLock()
  128. defer pm.mu.RUnlock()
  129. if pm.Providers == nil {
  130. return nil, fmt.Errorf("pricing not loaded")
  131. }
  132. providerPricing := (*pm.Providers)[provider]
  133. if providerPricing == nil {
  134. return nil, fmt.Errorf("provider %s not found", provider)
  135. }
  136. instancePricing := (*providerPricing)[volumeType]
  137. if instancePricing == nil {
  138. return nil, fmt.Errorf("volume type %s not found for provider %s", volumeType, provider)
  139. }
  140. regionPricing := (*instancePricing)[region]
  141. if regionPricing == nil {
  142. return nil, fmt.Errorf("region %s not found for volume type %s in provider %s", region, volumeType, provider)
  143. }
  144. // Reconstruct NodePricing from Prices
  145. return &pricing.VolumePricing{
  146. Properties: pricing.VolumePricingProperties{
  147. Provider: provider,
  148. VolumeType: pricing.VolumeType(volumeType),
  149. Region: region,
  150. },
  151. Prices: *regionPricing,
  152. }, nil
  153. }
  154. func (pm *PricingModule) NewNodePricingReader(ctx context.Context) (reader.Reader[*pricing.NodePricing], error) {
  155. pm.mu.RLock()
  156. defer pm.mu.RUnlock()
  157. return reader.NewSliceReader(pm.pricingSet.Nodes), nil
  158. }
  159. func (pm *PricingModule) NewVolumePricingReader(ctx context.Context) (reader.Reader[*pricing.VolumePricing], error) {
  160. pm.mu.RLock()
  161. defer pm.mu.RUnlock()
  162. return reader.NewSliceReader(pm.pricingSet.Volumes), nil
  163. }
  164. // GetPricingSet returns the current in-memory pricing set
  165. func (pm *PricingModule) GetPricingSet() *pricing.PricingSet {
  166. pm.mu.RLock()
  167. defer pm.mu.RUnlock()
  168. return pm.pricingSet
  169. }
  170. // ComparePricingSet compares the current in-memory pricing set with a new one
  171. // Returns true if they are identical, false if different
  172. func (pm *PricingModule) ComparePricingSet(newPricingSet *pricing.PricingSet) (bool, error) {
  173. pm.mu.RLock()
  174. defer pm.mu.RUnlock()
  175. if pm.pricingSet == nil {
  176. return false, fmt.Errorf("current pricing set is nil")
  177. }
  178. if newPricingSet == nil {
  179. return false, fmt.Errorf("new pricing set is nil")
  180. }
  181. // Compare by serializing both to JSON and computing checksums
  182. currentJSON, err := pm.serializePricingSet(pm.pricingSet)
  183. if err != nil {
  184. return false, fmt.Errorf("failed to serialize current pricing set: %w", err)
  185. }
  186. newJSON, err := pm.serializePricingSet(newPricingSet)
  187. if err != nil {
  188. return false, fmt.Errorf("failed to serialize new pricing set: %w", err)
  189. }
  190. return string(currentJSON) == string(newJSON), nil
  191. }
  192. // UpdatePricingSet replaces the current pricing set with a new one and re-indexes it
  193. func (pm *PricingModule) UpdatePricingSet(ctx context.Context, newPricingSet *pricing.PricingSet) error {
  194. if newPricingSet == nil {
  195. return fmt.Errorf("new pricing set is nil")
  196. }
  197. pm.mu.Lock()
  198. defer pm.mu.Unlock()
  199. // Store the new pricing set
  200. pm.pricingSet = newPricingSet
  201. // Re-index the pricing data
  202. err := pm.indexPricingSet(ctx, newPricingSet)
  203. if err != nil {
  204. return fmt.Errorf("failed to index new pricing set: %w", err)
  205. }
  206. log.Infof("Updated pricing set: %d node pricing records and %d volume pricing records",
  207. len(newPricingSet.Nodes), len(newPricingSet.Volumes))
  208. return nil
  209. }
  210. // serializePricingSet converts a pricing set to JSON bytes for comparison
  211. func (pm *PricingModule) serializePricingSet(ps *pricing.PricingSet) ([]byte, error) {
  212. return json.Marshal(ps)
  213. }
  214. // backgroundRefresh periodically fetches new pricing data and updates the module
  215. func (pm *PricingModule) backgroundRefresh() {
  216. defer close(pm.doneCh)
  217. ticker := time.NewTicker(pm.config.RefreshInterval)
  218. defer ticker.Stop()
  219. for {
  220. select {
  221. case <-ticker.C:
  222. log.Infof("Starting scheduled pricing refresh for %s (%s)", pm.config.Provider, pm.config.Currency)
  223. // Fetch new pricing data
  224. newPricingSet, err := GeneratePricingForProvider(pm.config.Provider, pm.config.Currency)
  225. if err != nil {
  226. log.Errorf("Failed to refresh pricing data: %v", err)
  227. continue
  228. }
  229. // Compare with existing data
  230. isIdentical, err := pm.ComparePricingSet(newPricingSet)
  231. if err != nil {
  232. log.Errorf("Failed to compare pricing data: %v", err)
  233. continue
  234. }
  235. if isIdentical {
  236. log.Infof("Pricing data unchanged, skipping update")
  237. continue
  238. }
  239. // Update with new data
  240. ctx := context.Background()
  241. if err := pm.UpdatePricingSet(ctx, newPricingSet); err != nil {
  242. log.Errorf("Failed to update pricing data: %v", err)
  243. continue
  244. }
  245. log.Infof("Successfully refreshed pricing data")
  246. case <-pm.stopCh:
  247. log.Infof("Stopping background pricing refresh")
  248. return
  249. }
  250. }
  251. }
  252. // Stop gracefully stops the background refresh goroutine
  253. func (pm *PricingModule) Stop() {
  254. if pm.config.RefreshInterval > 0 {
  255. close(pm.stopCh)
  256. <-pm.doneCh
  257. log.Infof("Background pricing refresh stopped")
  258. }
  259. }