package basic

import (
	"context"
	"os"
	"testing"

	"github.com/opencost/opencost/core/pkg/pricing"
	"github.com/opencost/opencost/core/pkg/reader"
	"github.com/opencost/opencost/core/pkg/storage"
	"github.com/opencost/opencost/core/pkg/unit"
	"github.com/stretchr/testify/require"
)

// TestPricingModule runs the full PricingModule test suite once per
// PricingStore implementation: an in-memory store and a storage-backed store
// that uses a temporary directory.
func TestPricingModule(t *testing.T) {
	memoryPricingStore := pricing.NewMemoryPricingStore()

	filePricingStore, err := pricing.NewStoragePricingStore(t.Context(), newFileStorage(t), "pricing.json")
	require.NoError(t, err)

	// A slice (rather than a map) keeps the subtest order deterministic.
	stores := []struct {
		name  string
		store pricing.PricingStore
	}{
		{name: "MemoryPricingStore", store: memoryPricingStore},
		{name: "StoragePricingStore", store: filePricingStore},
	}

	for _, s := range stores {
		t.Run(s.name, testPricingModuleWithStore(s.store))
	}
}

// testPricingModuleWithStore returns a test function that builds a single
// PricingModule on top of store and runs every subtest against it.
//
// The subtests share that one module and run sequentially, so prices written
// by the setter subtests are visible to the subtests that follow them.
func testPricingModuleWithStore(store pricing.PricingStore) func(t *testing.T) {
	return func(t *testing.T) {
		ctx := t.Context()

		pm, err := NewBasicPricingModule(store)
		require.NoError(t, err)

		t.Run("DefaultPricing", func(t *testing.T) {
			testDefaultPricing(t, ctx, pm)
		})

		t.Run("SetCurrency", func(t *testing.T) {
			testSetCurrency(t, ctx, pm)
		})

		t.Run("SetNodePricePerCPUCoreHour", func(t *testing.T) {
			testSetNodePricePerCPUCoreHour(t, ctx, pm)
		})

		t.Run("SetNodePricePerRAMGiBHour", func(t *testing.T) {
			testSetNodePricePerRAMGiBHour(t, ctx, pm)
		})

		t.Run("SetNodePricePerGPUHour", func(t *testing.T) {
			testSetNodePricePerGPUHour(t, ctx, pm)
		})

		t.Run("SetNodePricePerLocalDiskGiBHour", func(t *testing.T) {
			testSetNodePricePerLocalDiskGiBHour(t, ctx, pm)
		})

		t.Run("SetVolumePricePerStorageGiBHour", func(t *testing.T) {
			testSetVolumePricePerStorageGiBHour(t, ctx, pm)
		})

		t.Run("NewNodePricingReader", func(t *testing.T) {
			testNewNodePricingReader(t, ctx, pm)
		})

		t.Run("NewVolumePricingReader", func(t *testing.T) {
			testNewVolumePricingReader(t, ctx, pm)
		})

		t.Run("ModulePersistence", func(t *testing.T) {
			// Create a new PricingModule with the same store
			pm2, err := NewBasicPricingModule(store)
			require.NoError(t, err)

			// Verify that node pricing is available through the second module
			np, err := pm2.getNodePricing(ctx)
			if err != nil {
				t.Fatalf("Failed to get node pricing: %v", err)
			}
			if np == nil {
				t.Fatal("Expected node pricing to be persisted")
			}
		})
	}
}

// testDefaultPricing verifies that a freshly created PricingModule reports USD
// as its currency and exposes the default CPU, RAM, GPU and volume prices.
func testDefaultPricing(t *testing.T, ctx context.Context, pm *PricingModule) {
	// Test default currency
	currency := pm.GetCurrency()
	if currency != unit.USD {
		t.Errorf("Expected default currency to be USD, got %s", currency)
	}

	// Test default node pricing
	np, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing: %v", err)
	}
	if np == nil {
		t.Fatal("Expected node pricing to exist")
	}

	prices, err := np.Prices.GetPricesInCurrency(unit.USD)
	if err != nil {
		t.Fatalf("Failed to get prices in USD: %v", err)
	}

	// Verify default prices exist
	foundCPU := false
	foundRAM := false
	foundGPU := false
	for _, price := range prices {
		switch price.Unit {
		case unit.VCPUHour:
			foundCPU = true
			if price.Price != DefaultNodePricePerVCPUHour {
				t.Errorf("Expected CPU price to be %f, got %f", DefaultNodePricePerVCPUHour, price.Price)
			}
		case unit.RAMGiBHour:
			foundRAM = true
			if price.Price != DefaultNodePricePerRAMGiBHour {
				t.Errorf("Expected RAM price to be %f, got %f", DefaultNodePricePerRAMGiBHour, price.Price)
			}
		case unit.GPUHour:
			foundGPU = true
			if price.Price != DefaultNodePricePerGPUHour {
				t.Errorf("Expected GPU price to be %f, got %f", DefaultNodePricePerGPUHour, price.Price)
			}
		}
	}

	if !foundCPU {
		t.Error("Expected to find CPU pricing")
	}
	if !foundRAM {
		t.Error("Expected to find RAM pricing")
	}
	if !foundGPU {
		t.Error("Expected to find GPU pricing")
	}

	// Test default volume pricing
	vp, err := pm.getVolumePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get volume pricing: %v", err)
	}
	if vp == nil {
		t.Fatal("Expected volume pricing to exist")
	}

	volumePrices, err := vp.Prices.GetPricesInCurrency(unit.USD)
	if err != nil {
		t.Fatalf("Failed to get volume prices in USD: %v", err)
	}

	foundVolume := false
	for _, price := range volumePrices {
		if price.Unit == unit.StorageGiBHour {
			foundVolume = true
			if price.Price != DefaultVolumePricePerGiBHour {
				t.Errorf("Expected volume price to be %f, got %f", DefaultVolumePricePerGiBHour, price.Price)
			}
		}
	}
	if !foundVolume {
		t.Error("Expected to find volume pricing")
	}
}

// testSetCurrency verifies that SetCurrency switches the module currency to
// EUR while leaving the units and numeric prices of node and volume pricing
// unchanged. The currency is switched back to USD when the test finishes, even
// if an assertion aborts the test early.
func testSetCurrency(t *testing.T, ctx context.Context, pm *PricingModule) {
	// Get current pricing to compare later
	npBefore, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing before currency change: %v", err)
	}
	pricesBefore, err := npBefore.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices before currency change: %v", err)
	}

	vpBefore, err := pm.getVolumePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get volume pricing before currency change: %v", err)
	}
	volumePricesBefore, err := vpBefore.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get volume prices before currency change: %v", err)
	}

	// Change currency to EUR
	err = pm.SetCurrency(ctx, unit.EUR)
	if err != nil {
		t.Fatalf("Failed to set currency: %v", err)
	}

	// Change back to USD for other tests. Registered as a cleanup so the shared
	// module is restored even when a t.Fatalf below ends this test early.
	t.Cleanup(func() {
		if err := pm.SetCurrency(ctx, unit.USD); err != nil {
			t.Errorf("Failed to set currency back to USD: %v", err)
		}
	})

	// Verify currency changed
	currency := pm.GetCurrency()
	if currency != unit.EUR {
		t.Errorf("Expected currency to be EUR, got %s", currency)
	}

	// Verify node pricing units and prices remain the same, only currency changed
	npAfter, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing after currency change: %v", err)
	}
	pricesAfter, err := npAfter.Prices.GetPricesInCurrency(unit.EUR)
	if err != nil {
		t.Fatalf("Failed to get prices after currency change: %v", err)
	}

	if len(pricesBefore) != len(pricesAfter) {
		t.Errorf("Expected same number of prices, got %d before and %d after", len(pricesBefore), len(pricesAfter))
	}

	// Create maps for easier comparison
	beforeMap := make(map[unit.Unit]float64, len(pricesBefore))
	for _, p := range pricesBefore {
		beforeMap[p.Unit] = p.Price
	}

	afterMap := make(map[unit.Unit]float64, len(pricesAfter))
	for _, p := range pricesAfter {
		afterMap[p.Unit] = p.Price
		if p.Currency != unit.EUR {
			t.Errorf("Expected currency to be EUR, got %s", p.Currency)
		}
	}

	// Verify units and prices match. The loop variable is named u so that it
	// does not shadow the imported unit package.
	for u, priceBefore := range beforeMap {
		priceAfter, ok := afterMap[u]
		if !ok {
			t.Errorf("Unit %s not found after currency change", u)
			continue
		}
		if priceBefore != priceAfter {
			t.Errorf("Price for unit %s changed from %f to %f", u, priceBefore, priceAfter)
		}
	}

	// Verify volume pricing units and prices remain the same
	vpAfter, err := pm.getVolumePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get volume pricing after currency change: %v", err)
	}
	volumePricesAfter, err := vpAfter.Prices.GetPricesInCurrency(unit.EUR)
	if err != nil {
		t.Fatalf("Failed to get volume prices after currency change: %v", err)
	}

	// The positional comparison below indexes both slices with the same index,
	// so a length mismatch must stop the test instead of risking a panic.
	if len(volumePricesBefore) != len(volumePricesAfter) {
		t.Fatalf("Expected same number of volume prices, got %d before and %d after", len(volumePricesBefore), len(volumePricesAfter))
	}

	for i, priceBefore := range volumePricesBefore {
		priceAfter := volumePricesAfter[i]
		if priceAfter.Currency != unit.EUR {
			t.Errorf("Expected currency to be EUR, got %s", priceAfter.Currency)
		}
		if priceBefore.Unit != priceAfter.Unit {
			t.Errorf("Unit changed from %s to %s", priceBefore.Unit, priceAfter.Unit)
		}
		if priceBefore.Price != priceAfter.Price {
			t.Errorf("Price changed from %f to %f", priceBefore.Price, priceAfter.Price)
		}
	}
}

// testSetNodePricePerCPUCoreHour verifies that a price written with
// SetNodePricePerCPUCoreHour is reported back under unit.VCPUHour.
func testSetNodePricePerCPUCoreHour(t *testing.T, ctx context.Context, pm *PricingModule) {
	newPrice := 0.075

	err := pm.SetNodePricePerCPUCoreHour(ctx, newPrice)
	if err != nil {
		t.Fatalf("Failed to set CPU price: %v", err)
	}

	// Verify the price was set
	np, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing: %v", err)
	}

	prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices: %v", err)
	}

	found := false
	for _, price := range prices {
		if price.Unit == unit.VCPUHour {
			found = true
			if price.Price != newPrice {
				t.Errorf("Expected CPU price to be %f, got %f", newPrice, price.Price)
			}
		}
	}
	if !found {
		t.Error("Expected to find CPU pricing")
	}
}

// testSetNodePricePerRAMGiBHour verifies that a price written with
// SetNodePricePerRAMGiBHour is reported back under unit.RAMGiBHour.
func testSetNodePricePerRAMGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
	newPrice := 0.008

	err := pm.SetNodePricePerRAMGiBHour(ctx, newPrice)
	if err != nil {
		t.Fatalf("Failed to set RAM price: %v", err)
	}

	// Verify the price was set
	np, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing: %v", err)
	}

	prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices: %v", err)
	}

	found := false
	for _, price := range prices {
		if price.Unit == unit.RAMGiBHour {
			found = true
			if price.Price != newPrice {
				t.Errorf("Expected RAM price to be %f, got %f", newPrice, price.Price)
			}
		}
	}
	if !found {
		t.Error("Expected to find RAM pricing")
	}
}

// testSetNodePricePerGPUHour verifies that a price written with
// SetNodePricePerGPUHour is reported back under unit.GPUHour.
func testSetNodePricePerGPUHour(t *testing.T, ctx context.Context, pm *PricingModule) {
	newPrice := 2.0

	err := pm.SetNodePricePerGPUHour(ctx, newPrice)
	if err != nil {
		t.Fatalf("Failed to set GPU price: %v", err)
	}

	// Verify the price was set
	np, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing: %v", err)
	}

	prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices: %v", err)
	}

	found := false
	for _, price := range prices {
		if price.Unit == unit.GPUHour {
			found = true
			if price.Price != newPrice {
				t.Errorf("Expected GPU price to be %f, got %f", newPrice, price.Price)
			}
		}
	}
	if !found {
		t.Error("Expected to find GPU pricing")
	}
}

// testSetNodePricePerLocalDiskGiBHour verifies that a price written with
// SetNodePricePerLocalDiskGiBHour is reported back in the node pricing under
// unit.StorageGiBHour.
func testSetNodePricePerLocalDiskGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
	newPrice := 0.0007

	err := pm.SetNodePricePerLocalDiskGiBHour(ctx, newPrice)
	if err != nil {
		t.Fatalf("Failed to set local disk price: %v", err)
	}

	// Verify the price was set
	np, err := pm.getNodePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get node pricing: %v", err)
	}

	prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices: %v", err)
	}

	found := false
	for _, price := range prices {
		if price.Unit == unit.StorageGiBHour {
			found = true
			if price.Price != newPrice {
				t.Errorf("Expected local disk price to be %f, got %f", newPrice, price.Price)
			}
		}
	}
	if !found {
		t.Error("Expected to find local disk pricing")
	}
}

// testSetVolumePricePerStorageGiBHour verifies that a price written with
// SetVolumePricePerStorageGiBHour is reported back in the volume pricing under
// unit.StorageGiBHour.
func testSetVolumePricePerStorageGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
	newPrice := 0.0003

	err := pm.SetVolumePricePerStorageGiBHour(ctx, newPrice)
	if err != nil {
		t.Fatalf("Failed to set volume storage price: %v", err)
	}

	// Verify the price was set
	vp, err := pm.getVolumePricing(ctx)
	if err != nil {
		t.Fatalf("Failed to get volume pricing: %v", err)
	}

	prices, err := vp.Prices.GetPricesInCurrency(pm.GetCurrency())
	if err != nil {
		t.Fatalf("Failed to get prices: %v", err)
	}

	found := false
	for _, price := range prices {
		if price.Unit == unit.StorageGiBHour {
			found = true
			if price.Price != newPrice {
				t.Errorf("Expected volume storage price to be %f, got %f", newPrice, price.Price)
			}
		}
	}
	if !found {
		t.Error("Expected to find volume storage pricing")
	}
}

// testNewNodePricingReader verifies that NewNodePricingReader returns a
// non-nil reader that yields exactly one non-nil *pricing.NodePricing before
// reporting reader.Done.
func testNewNodePricingReader(t *testing.T, ctx context.Context, pm *PricingModule) {
	// Test that NewNodePricingReader always produces a reader
	rdr, err := pm.NewNodePricingReader(ctx)
	if err != nil {
		t.Fatalf("Failed to create node pricing reader: %v", err)
	}
	if rdr == nil {
		t.Fatal("Expected reader to be non-nil")
	}

	// Clean up. Deferred so the reader is closed even if t.Fatalf ends the
	// test from inside the read loop.
	defer func() {
		if err := rdr.Close(); err != nil {
			t.Errorf("Failed to close reader: %v", err)
		}
	}()

	// Test that the reader produces precisely one *NodePricing struct
	dst := make([]*pricing.NodePricing, 10) // Buffer larger than expected
	count := 0
	for {
		n, err := rdr.Read(ctx, dst)
		count += n

		// Verify all read items are non-nil
		for i := 0; i < n; i++ {
			if dst[i] == nil {
				t.Error("Expected non-nil NodePricing")
			}
		}

		if err == reader.Done {
			break
		}
		if err != nil {
			t.Fatalf("Reader error: %v", err)
		}
	}

	if count != 1 {
		t.Errorf("Expected reader to produce exactly 1 NodePricing, got %d", count)
	}
}

// testNewVolumePricingReader verifies that NewVolumePricingReader returns a
// non-nil reader that yields exactly one non-nil *pricing.VolumePricing before
// reporting reader.Done.
func testNewVolumePricingReader(t *testing.T, ctx context.Context, pm *PricingModule) {
	// Test that NewVolumePricingReader always produces a reader
	rdr, err := pm.NewVolumePricingReader(ctx)
	if err != nil {
		t.Fatalf("Failed to create volume pricing reader: %v", err)
	}
	if rdr == nil {
		t.Fatal("Expected reader to be non-nil")
	}

	// Clean up. Deferred so the reader is closed even if t.Fatalf ends the
	// test from inside the read loop.
	defer func() {
		if err := rdr.Close(); err != nil {
			t.Errorf("Failed to close reader: %v", err)
		}
	}()

	// Test that the reader produces precisely one *VolumePricing struct
	dst := make([]*pricing.VolumePricing, 10) // Buffer larger than expected
	count := 0
	for {
		n, err := rdr.Read(ctx, dst)
		count += n

		// Verify all read items are non-nil
		for i := 0; i < n; i++ {
			if dst[i] == nil {
				t.Error("Expected non-nil VolumePricing")
			}
		}

		if err == reader.Done {
			break
		}
		if err != nil {
			t.Fatalf("Reader error: %v", err)
		}
	}

	if count != 1 {
		t.Errorf("Expected reader to produce exactly 1 VolumePricing, got %d", count)
	}
}

// newFileStorage returns a file-backed storage.Storage rooted in a fresh
// temporary directory. The directory is removed when the calling test and all
// of its subtests have finished, not when this helper returns, so it still
// exists while the storage is in use.
func newFileStorage(t *testing.T) storage.Storage {
	t.Helper()

	tempDir, err := os.MkdirTemp("", "pricing-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	t.Cleanup(func() {
		if err := os.RemoveAll(tempDir); err != nil {
			t.Errorf("Failed to remove temp directory: %v", err)
		}
	})

	return storage.NewFileStorage(tempDir)
}