package public import ( "context" "fmt" "sync" "time" "github.com/opencost/opencost/core/pkg/log" "github.com/opencost/opencost/core/pkg/model/shared" "github.com/opencost/opencost/core/pkg/pricing" "github.com/opencost/opencost/core/pkg/reader" "github.com/opencost/opencost/core/pkg/unit" ) // PricingModule must satisfy the pricing.PricingModule interface var _ pricing.PricingModule = (*PricingModule)(nil) type PricingModuleConfig struct { Provider shared.Provider Currency unit.Currency RefreshInterval time.Duration } type PricingModule struct { config PricingModuleConfig Providers *ProviderPricing `json:"provider" yaml:"provider"` pricingSet *pricing.PricingSet mu sync.RWMutex stopCh chan struct{} doneCh chan struct{} } func NewPricingModule(config PricingModuleConfig) (*PricingModule, error) { pm := &PricingModule{ config: config, Providers: &ProviderPricing{}, stopCh: make(chan struct{}), doneCh: make(chan struct{}), } ctx := context.Background() // Generate pricing data directly from the provider API pricingSet, err := GeneratePricingForProvider(config.Provider, config.Currency) if err != nil { return nil, fmt.Errorf("failed to generate pricing: %w", err) } // Store the pricing set for reader access pm.pricingSet = pricingSet err = pm.indexPricingSet(ctx, pricingSet) if err != nil { return nil, fmt.Errorf("failed to index pricing: %w", err) } // Start background refresh if configured if config.RefreshInterval > 0 { go pm.backgroundRefresh() log.Infof("Started background pricing refresh with interval: %v", config.RefreshInterval) } return pm, nil } type ProviderPricing map[shared.Provider]*InstanceTypePricing type InstanceTypePricing map[string]*RegionPricing type RegionPricing map[string]*pricing.Prices func (pm *PricingModule) indexPricingSet(_ context.Context, pricingSet *pricing.PricingSet) error { providers := make(ProviderPricing) // Index nodes for _, node := range pricingSet.NodePricing { provider := node.Properties.Provider instanceType := node.Properties.InstanceType region := node.Properties.Region // Instance type map if providers[provider] == nil { instanceMap := make(InstanceTypePricing) providers[provider] = &instanceMap } // Region map if (*providers[provider])[instanceType] == nil { regionMap := make(RegionPricing) (*providers[provider])[instanceType] = ®ionMap } (*(*providers[provider])[instanceType])[region] = &node.Prices } // Index volumes for _, volume := range pricingSet.PersistentVolumePricing { provider := volume.Properties.Provider volumeType := string(volume.Properties.VolumeType) region := volume.Properties.Region // Instance type map if providers[provider] == nil { instanceMap := make(InstanceTypePricing) providers[provider] = &instanceMap } // Region map if (*providers[provider])[volumeType] == nil { regionMap := make(RegionPricing) (*providers[provider])[volumeType] = ®ionMap } (*(*providers[provider])[volumeType])[region] = &volume.Prices } pm.Providers = &providers log.Infof("Indexed %d node pricing records and %d volume pricing records for provider %s (%s)", len(pricingSet.NodePricing), len(pricingSet.PersistentVolumePricing), pm.config.Provider, pm.config.Currency) return nil } // GetNodePricing provides fast lookup for node pricing by provider, instance type, and region func (pm *PricingModule) GetNodePricing(ctx context.Context, props pricing.NodePricingProperties) (*pricing.NodePricing, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() if pm.Providers == nil { return nil, fmt.Errorf("pricing not loaded") } provider := props.Provider instanceType := props.InstanceType region := props.Region providerPricing := (*pm.Providers)[provider] if providerPricing == nil { return nil, fmt.Errorf("provider %s not found", provider) } instancePricing := (*providerPricing)[instanceType] if instancePricing == nil { return nil, fmt.Errorf("instance type %s not found for provider %s", instanceType, provider) } regionPricing := (*instancePricing)[region] if regionPricing == nil { return nil, fmt.Errorf("region %s not found for instance type %s in provider %s", region, instanceType, provider) } // Reconstruct NodePricing from Prices return &pricing.NodePricing{ Properties: pricing.NodePricingProperties{ Provider: provider, InstanceType: instanceType, Region: region, }, Prices: *regionPricing, }, nil } // GetPersistentVolumePricing provides fast lookup for volume pricing by provider, volume type, and region func (pm *PricingModule) GetPersistentVolumePricing(ctx context.Context, props pricing.PersistentVolumePricingProperties) (*pricing.PersistentVolumePricing, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() if pm.Providers == nil { return nil, fmt.Errorf("pricing not loaded") } provider := props.Provider volumeType := string(props.VolumeType) region := props.Region providerPricing := (*pm.Providers)[provider] if providerPricing == nil { return nil, fmt.Errorf("provider %s not found", provider) } instancePricing := (*providerPricing)[volumeType] if instancePricing == nil { return nil, fmt.Errorf("volume type %s not found for provider %s", volumeType, provider) } regionPricing := (*instancePricing)[region] if regionPricing == nil { return nil, fmt.Errorf("region %s not found for volume type %s in provider %s", region, volumeType, provider) } // Reconstruct NodePricing from Prices return &pricing.PersistentVolumePricing{ Properties: pricing.PersistentVolumePricingProperties{ Provider: provider, VolumeType: pricing.VolumeType(volumeType), Region: region, }, Prices: *regionPricing, }, nil } func (pm *PricingModule) NewNodePricingReader(ctx context.Context) (reader.Reader[*pricing.NodePricing], error) { pm.mu.RLock() defer pm.mu.RUnlock() return reader.NewSliceReader(pm.pricingSet.NodePricing), nil } func (pm *PricingModule) NewPersistentVolumePricingReader(ctx context.Context) (reader.Reader[*pricing.PersistentVolumePricing], error) { pm.mu.RLock() defer pm.mu.RUnlock() return reader.NewSliceReader(pm.pricingSet.PersistentVolumePricing), nil } // GetClusterPricing returns cluster pricing matching the given provider. func (pm *PricingModule) GetClusterPricing(ctx context.Context, props pricing.ClusterPricingProperties) (*pricing.ClusterPricing, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() if pm.pricingSet == nil { return nil, fmt.Errorf("pricing not loaded") } for _, cp := range pm.pricingSet.ClusterPricing { if cp.Properties.Provider == props.Provider { return cp, nil } } return nil, fmt.Errorf("cluster pricing not found for provider=%s", props.Provider) } func (pm *PricingModule) NewClusterPricingReader(ctx context.Context) (reader.Reader[*pricing.ClusterPricing], error) { pm.mu.RLock() defer pm.mu.RUnlock() return reader.NewSliceReader(pm.pricingSet.ClusterPricing), nil } // GetNetworkPricing returns network pricing matching the given provider, traffic // direction, traffic type, and NAT gateway flag. func (pm *PricingModule) GetNetworkPricing(ctx context.Context, props pricing.NetworkPricingProperties) (*pricing.NetworkPricing, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() if pm.pricingSet == nil { return nil, fmt.Errorf("pricing not loaded") } for _, np := range pm.pricingSet.NetworkPricing { if np.Properties.Provider == props.Provider && np.Properties.TrafficDirection == props.TrafficDirection && np.Properties.TrafficType == props.TrafficType && np.Properties.IsNatGateway == props.IsNatGateway { return np, nil } } return nil, fmt.Errorf("network pricing not found for provider=%s, trafficDirection=%s, trafficType=%s, isNatGateway=%t", props.Provider, props.TrafficDirection, props.TrafficType, props.IsNatGateway) } func (pm *PricingModule) NewNetworkPricingReader(ctx context.Context) (reader.Reader[*pricing.NetworkPricing], error) { pm.mu.RLock() defer pm.mu.RUnlock() return reader.NewSliceReader(pm.pricingSet.NetworkPricing), nil } // GetServicePricing returns service pricing matching the given provider and region. func (pm *PricingModule) GetServicePricing(ctx context.Context, props pricing.ServicePricingProperties) (*pricing.ServicePricing, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() if pm.pricingSet == nil { return nil, fmt.Errorf("pricing not loaded") } for _, sp := range pm.pricingSet.ServicePricing { if sp.Properties.Provider == props.Provider && sp.Properties.Region == props.Region { return sp, nil } } return nil, fmt.Errorf("service pricing not found for provider=%s, region=%s", props.Provider, props.Region) } func (pm *PricingModule) NewServicePricingReader(ctx context.Context) (reader.Reader[*pricing.ServicePricing], error) { pm.mu.RLock() defer pm.mu.RUnlock() return reader.NewSliceReader(pm.pricingSet.ServicePricing), nil } // GetPricingSet returns the current in-memory pricing set func (pm *PricingModule) GetPricingSet(ctx context.Context) (*pricing.PricingSet, error) { if err := ctx.Err(); err != nil { return nil, err } pm.mu.RLock() defer pm.mu.RUnlock() return pm.pricingSet, nil } // TODO: Make this a const? This string is correct, but is also defined in KCM. func (pm *PricingModule) SourceKind() string { return "public" } // TODO: This seems like a reasonable choice for a source name... but let's think about it a bit more. func (pm *PricingModule) SourceName() string { return string(pm.config.Provider) } func (pm *PricingModule) Checksum(ctx context.Context) (string, error) { pm.mu.RLock() defer pm.mu.RUnlock() if pm.pricingSet == nil { return "", fmt.Errorf("pricing not loaded") } return pm.pricingSet.Checksum() } // ComparePricingSet compares the current in-memory pricing set with a new one // Returns true if they are identical, false if different func (pm *PricingModule) ComparePricingSet(newPricingSet *pricing.PricingSet) (bool, error) { pm.mu.RLock() defer pm.mu.RUnlock() sum, err := pm.pricingSet.Checksum() if err != nil { return false, fmt.Errorf("failed to checksum current pricing set: %w", err) } newSum, err := newPricingSet.Checksum() if err != nil { return false, fmt.Errorf("failed to serialize new pricing set: %w", err) } return sum == newSum, nil } // UpdatePricingSet replaces the current pricing set with a new one and re-indexes it func (pm *PricingModule) UpdatePricingSet(ctx context.Context, newPricingSet *pricing.PricingSet) error { if newPricingSet == nil { return fmt.Errorf("new pricing set is nil") } pm.mu.Lock() defer pm.mu.Unlock() // Store the new pricing set pm.pricingSet = newPricingSet // Re-index the pricing data err := pm.indexPricingSet(ctx, newPricingSet) if err != nil { return fmt.Errorf("failed to index new pricing set: %w", err) } log.Infof("Updated pricing set: %d node pricing records and %d volume pricing records", len(newPricingSet.NodePricing), len(newPricingSet.PersistentVolumePricing)) return nil } // backgroundRefresh periodically fetches new pricing data and updates the module func (pm *PricingModule) backgroundRefresh() { defer close(pm.doneCh) ticker := time.NewTicker(pm.config.RefreshInterval) defer ticker.Stop() for { select { case <-ticker.C: log.Infof("Starting scheduled pricing refresh for %s (%s)", pm.config.Provider, pm.config.Currency) // Fetch new pricing data newPricingSet, err := GeneratePricingForProvider(pm.config.Provider, pm.config.Currency) if err != nil { log.Errorf("Failed to refresh pricing data: %v", err) continue } // Compare with existing data isIdentical, err := pm.ComparePricingSet(newPricingSet) if err != nil { log.Errorf("Failed to compare pricing data: %v", err) continue } if isIdentical { log.Infof("Pricing data unchanged, skipping update") continue } // Update with new data ctx := context.Background() if err := pm.UpdatePricingSet(ctx, newPricingSet); err != nil { log.Errorf("Failed to update pricing data: %v", err) continue } log.Infof("Successfully refreshed pricing data") case <-pm.stopCh: log.Infof("Stopping background pricing refresh") return } } } // Stop gracefully stops the background refresh goroutine func (pm *PricingModule) Stop() { if pm.config.RefreshInterval > 0 { close(pm.stopCh) <-pm.doneCh log.Infof("Background pricing refresh stopped") } }