2
0

module_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. package basic
  2. import (
  3. "context"
  4. "os"
  5. "testing"
  6. "github.com/opencost/opencost/core/pkg/pricing"
  7. "github.com/opencost/opencost/core/pkg/reader"
  8. "github.com/opencost/opencost/core/pkg/storage"
  9. "github.com/opencost/opencost/core/pkg/unit"
  10. "github.com/stretchr/testify/require"
  11. )
  12. func TestPricingModule(t *testing.T) {
  13. memoryPricingStore := pricing.NewMemoryPricingStore()
  14. filePricingStore, err := pricing.NewStoragePricingStore(t.Context(), newFileStorage(t), "pricing.json")
  15. require.NoError(t, err)
  16. stores := map[string]pricing.PricingStore{
  17. "MemoryPricingStore": memoryPricingStore,
  18. "StoragePricingStore": filePricingStore,
  19. }
  20. for name, store := range stores {
  21. t.Run(name, testPricingModuleWithStore(store))
  22. }
  23. }
  24. func testPricingModuleWithStore(store pricing.PricingStore) func(t *testing.T) {
  25. return func(t *testing.T) {
  26. ctx := t.Context()
  27. pm, err := NewBasicPricingModule(store)
  28. require.NoError(t, err)
  29. t.Run("DefaultPricing", func(t *testing.T) {
  30. testDefaultPricing(t, ctx, pm)
  31. })
  32. t.Run("SetCurrency", func(t *testing.T) {
  33. testSetCurrency(t, ctx, pm)
  34. })
  35. t.Run("SetNodePricePerCPUCoreHour", func(t *testing.T) {
  36. testSetNodePricePerCPUCoreHour(t, ctx, pm)
  37. })
  38. t.Run("SetNodePricePerRAMGiBHour", func(t *testing.T) {
  39. testSetNodePricePerRAMGiBHour(t, ctx, pm)
  40. })
  41. t.Run("SetNodePricePerGPUHour", func(t *testing.T) {
  42. testSetNodePricePerGPUHour(t, ctx, pm)
  43. })
  44. t.Run("SetNodePricePerLocalDiskGiBHour", func(t *testing.T) {
  45. testSetNodePricePerLocalDiskGiBHour(t, ctx, pm)
  46. })
  47. t.Run("SetVolumePricePerStorageGiBHour", func(t *testing.T) {
  48. testSetVolumePricePerStorageGiBHour(t, ctx, pm)
  49. })
  50. t.Run("NewNodePricingReader", func(t *testing.T) {
  51. testNewNodePricingReader(t, ctx, pm)
  52. })
  53. t.Run("NewVolumePricingReader", func(t *testing.T) {
  54. testNewVolumePricingReader(t, ctx, pm)
  55. })
  56. t.Run("ModulePersistence", func(t *testing.T) {
  57. // Create a new PricingModule with the same store
  58. pm2, err := NewBasicPricingModule(store)
  59. require.NoError(t, err)
  60. // Verify that pricing persists
  61. np, err := pm2.getNodePricing(ctx)
  62. if err != nil {
  63. t.Fatalf("Failed to get node pricing: %v", err)
  64. }
  65. if np == nil {
  66. t.Fatal("Expected node pricing to be persisted")
  67. }
  68. })
  69. }
  70. }
  71. // testDefaultPricing verifies that a freshly created PricingModule contains default pricing
  72. func testDefaultPricing(t *testing.T, ctx context.Context, pm *PricingModule) {
  73. // Test default currency
  74. currency := pm.GetCurrency()
  75. if currency != unit.USD {
  76. t.Errorf("Expected default currency to be USD, got %s", currency)
  77. }
  78. // Test default node pricing
  79. np, err := pm.getNodePricing(ctx)
  80. if err != nil {
  81. t.Fatalf("Failed to get node pricing: %v", err)
  82. }
  83. if np == nil {
  84. t.Fatal("Expected node pricing to exist")
  85. }
  86. prices, err := np.Prices.GetPricesInCurrency(unit.USD)
  87. if err != nil {
  88. t.Fatalf("Failed to get prices in USD: %v", err)
  89. }
  90. // Verify default prices exist
  91. foundCPU := false
  92. foundRAM := false
  93. foundGPU := false
  94. for _, price := range prices {
  95. switch price.Unit {
  96. case unit.VCPUHour:
  97. foundCPU = true
  98. if price.Price != DefaultNodePricePerVCPUHour {
  99. t.Errorf("Expected CPU price to be %f, got %f", DefaultNodePricePerVCPUHour, price.Price)
  100. }
  101. case unit.RAMGiBHour:
  102. foundRAM = true
  103. if price.Price != DefaultNodePricePerRAMGiBHour {
  104. t.Errorf("Expected RAM price to be %f, got %f", DefaultNodePricePerRAMGiBHour, price.Price)
  105. }
  106. case unit.GPUHour:
  107. foundGPU = true
  108. if price.Price != DefaultNodePricePerGPUHour {
  109. t.Errorf("Expected GPU price to be %f, got %f", DefaultNodePricePerGPUHour, price.Price)
  110. }
  111. }
  112. }
  113. if !foundCPU {
  114. t.Error("Expected to find CPU pricing")
  115. }
  116. if !foundRAM {
  117. t.Error("Expected to find RAM pricing")
  118. }
  119. if !foundGPU {
  120. t.Error("Expected to find GPU pricing")
  121. }
  122. // Test default volume pricing
  123. vp, err := pm.getVolumePricing(ctx)
  124. if err != nil {
  125. t.Fatalf("Failed to get volume pricing: %v", err)
  126. }
  127. if vp == nil {
  128. t.Fatal("Expected volume pricing to exist")
  129. }
  130. volumePrices, err := vp.Prices.GetPricesInCurrency(unit.USD)
  131. if err != nil {
  132. t.Fatalf("Failed to get volume prices in USD: %v", err)
  133. }
  134. foundVolume := false
  135. for _, price := range volumePrices {
  136. if price.Unit == unit.StorageGiBHour {
  137. foundVolume = true
  138. if price.Price != DefaultVolumePricePerGiBHour {
  139. t.Errorf("Expected volume price to be %f, got %f", DefaultVolumePricePerGiBHour, price.Price)
  140. }
  141. }
  142. }
  143. if !foundVolume {
  144. t.Error("Expected to find volume pricing")
  145. }
  146. }
  147. // testSetCurrency tests the SetCurrency function
  148. func testSetCurrency(t *testing.T, ctx context.Context, pm *PricingModule) {
  149. // Get current pricing to compare later
  150. npBefore, err := pm.getNodePricing(ctx)
  151. if err != nil {
  152. t.Fatalf("Failed to get node pricing before currency change: %v", err)
  153. }
  154. pricesBefore, err := npBefore.Prices.GetPricesInCurrency(pm.GetCurrency())
  155. if err != nil {
  156. t.Fatalf("Failed to get prices before currency change: %v", err)
  157. }
  158. vpBefore, err := pm.getVolumePricing(ctx)
  159. if err != nil {
  160. t.Fatalf("Failed to get volume pricing before currency change: %v", err)
  161. }
  162. volumePricesBefore, err := vpBefore.Prices.GetPricesInCurrency(pm.GetCurrency())
  163. if err != nil {
  164. t.Fatalf("Failed to get volume prices before currency change: %v", err)
  165. }
  166. // Change currency to EUR
  167. err = pm.SetCurrency(ctx, unit.EUR)
  168. if err != nil {
  169. t.Fatalf("Failed to set currency: %v", err)
  170. }
  171. // Verify currency changed
  172. currency := pm.GetCurrency()
  173. if currency != unit.EUR {
  174. t.Errorf("Expected currency to be EUR, got %s", currency)
  175. }
  176. // Verify node pricing units and prices remain the same, only currency changed
  177. npAfter, err := pm.getNodePricing(ctx)
  178. if err != nil {
  179. t.Fatalf("Failed to get node pricing after currency change: %v", err)
  180. }
  181. pricesAfter, err := npAfter.Prices.GetPricesInCurrency(unit.EUR)
  182. if err != nil {
  183. t.Fatalf("Failed to get prices after currency change: %v", err)
  184. }
  185. if len(pricesBefore) != len(pricesAfter) {
  186. t.Errorf("Expected same number of prices, got %d before and %d after", len(pricesBefore), len(pricesAfter))
  187. }
  188. // Create maps for easier comparison
  189. beforeMap := make(map[unit.Unit]float64)
  190. for _, p := range pricesBefore {
  191. beforeMap[p.Unit] = p.Price
  192. }
  193. afterMap := make(map[unit.Unit]float64)
  194. for _, p := range pricesAfter {
  195. afterMap[p.Unit] = p.Price
  196. if p.Currency != unit.EUR {
  197. t.Errorf("Expected currency to be EUR, got %s", p.Currency)
  198. }
  199. }
  200. // Verify units and prices match
  201. for unit, priceBefore := range beforeMap {
  202. priceAfter, ok := afterMap[unit]
  203. if !ok {
  204. t.Errorf("Unit %s not found after currency change", unit)
  205. continue
  206. }
  207. if priceBefore != priceAfter {
  208. t.Errorf("Price for unit %s changed from %f to %f", unit, priceBefore, priceAfter)
  209. }
  210. }
  211. // Verify volume pricing units and prices remain the same
  212. vpAfter, err := pm.getVolumePricing(ctx)
  213. if err != nil {
  214. t.Fatalf("Failed to get volume pricing after currency change: %v", err)
  215. }
  216. volumePricesAfter, err := vpAfter.Prices.GetPricesInCurrency(unit.EUR)
  217. if err != nil {
  218. t.Fatalf("Failed to get volume prices after currency change: %v", err)
  219. }
  220. if len(volumePricesBefore) != len(volumePricesAfter) {
  221. t.Errorf("Expected same number of volume prices, got %d before and %d after", len(volumePricesBefore), len(volumePricesAfter))
  222. }
  223. for i, priceBefore := range volumePricesBefore {
  224. priceAfter := volumePricesAfter[i]
  225. if priceAfter.Currency != unit.EUR {
  226. t.Errorf("Expected currency to be EUR, got %s", priceAfter.Currency)
  227. }
  228. if priceBefore.Unit != priceAfter.Unit {
  229. t.Errorf("Unit changed from %s to %s", priceBefore.Unit, priceAfter.Unit)
  230. }
  231. if priceBefore.Price != priceAfter.Price {
  232. t.Errorf("Price changed from %f to %f", priceBefore.Price, priceAfter.Price)
  233. }
  234. }
  235. // Change back to USD for other tests
  236. err = pm.SetCurrency(ctx, unit.USD)
  237. if err != nil {
  238. t.Fatalf("Failed to set currency back to USD: %v", err)
  239. }
  240. }
  241. // testSetNodePricePerCPUCoreHour tests the SetNodePricePerCPUCoreHour function
  242. func testSetNodePricePerCPUCoreHour(t *testing.T, ctx context.Context, pm *PricingModule) {
  243. newPrice := 0.075
  244. err := pm.SetNodePricePerCPUCoreHour(ctx, newPrice)
  245. if err != nil {
  246. t.Fatalf("Failed to set CPU price: %v", err)
  247. }
  248. // Verify the price was set
  249. np, err := pm.getNodePricing(ctx)
  250. if err != nil {
  251. t.Fatalf("Failed to get node pricing: %v", err)
  252. }
  253. prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
  254. if err != nil {
  255. t.Fatalf("Failed to get prices: %v", err)
  256. }
  257. found := false
  258. for _, price := range prices {
  259. if price.Unit == unit.VCPUHour {
  260. found = true
  261. if price.Price != newPrice {
  262. t.Errorf("Expected CPU price to be %f, got %f", newPrice, price.Price)
  263. }
  264. }
  265. }
  266. if !found {
  267. t.Error("Expected to find CPU pricing")
  268. }
  269. }
  270. // testSetNodePricePerRAMGiBHour tests the SetNodePricePerRAMGiBHour function
  271. func testSetNodePricePerRAMGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
  272. newPrice := 0.008
  273. err := pm.SetNodePricePerRAMGiBHour(ctx, newPrice)
  274. if err != nil {
  275. t.Fatalf("Failed to set RAM price: %v", err)
  276. }
  277. // Verify the price was set
  278. np, err := pm.getNodePricing(ctx)
  279. if err != nil {
  280. t.Fatalf("Failed to get node pricing: %v", err)
  281. }
  282. prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
  283. if err != nil {
  284. t.Fatalf("Failed to get prices: %v", err)
  285. }
  286. found := false
  287. for _, price := range prices {
  288. if price.Unit == unit.RAMGiBHour {
  289. found = true
  290. if price.Price != newPrice {
  291. t.Errorf("Expected RAM price to be %f, got %f", newPrice, price.Price)
  292. }
  293. }
  294. }
  295. if !found {
  296. t.Error("Expected to find RAM pricing")
  297. }
  298. }
  299. // testSetNodePricePerGPUHour tests the SetNodePricePerGPUHour function
  300. func testSetNodePricePerGPUHour(t *testing.T, ctx context.Context, pm *PricingModule) {
  301. newPrice := 2.0
  302. err := pm.SetNodePricePerGPUHour(ctx, newPrice)
  303. if err != nil {
  304. t.Fatalf("Failed to set GPU price: %v", err)
  305. }
  306. // Verify the price was set
  307. np, err := pm.getNodePricing(ctx)
  308. if err != nil {
  309. t.Fatalf("Failed to get node pricing: %v", err)
  310. }
  311. prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
  312. if err != nil {
  313. t.Fatalf("Failed to get prices: %v", err)
  314. }
  315. found := false
  316. for _, price := range prices {
  317. if price.Unit == unit.GPUHour {
  318. found = true
  319. if price.Price != newPrice {
  320. t.Errorf("Expected GPU price to be %f, got %f", newPrice, price.Price)
  321. }
  322. }
  323. }
  324. if !found {
  325. t.Error("Expected to find GPU pricing")
  326. }
  327. }
  328. // testSetNodePricePerLocalDiskGiBHour tests the SetNodePricePerLocalDiskGiBHour function
  329. func testSetNodePricePerLocalDiskGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
  330. newPrice := 0.0007
  331. err := pm.SetNodePricePerLocalDiskGiBHour(ctx, newPrice)
  332. if err != nil {
  333. t.Fatalf("Failed to set local disk price: %v", err)
  334. }
  335. // Verify the price was set
  336. np, err := pm.getNodePricing(ctx)
  337. if err != nil {
  338. t.Fatalf("Failed to get node pricing: %v", err)
  339. }
  340. prices, err := np.Prices.GetPricesInCurrency(pm.GetCurrency())
  341. if err != nil {
  342. t.Fatalf("Failed to get prices: %v", err)
  343. }
  344. found := false
  345. for _, price := range prices {
  346. if price.Unit == unit.StorageGiBHour {
  347. found = true
  348. if price.Price != newPrice {
  349. t.Errorf("Expected local disk price to be %f, got %f", newPrice, price.Price)
  350. }
  351. }
  352. }
  353. if !found {
  354. t.Error("Expected to find local disk pricing")
  355. }
  356. }
  357. // testSetVolumePricePerStorageGiBHour tests the SetVolumePricePerStorageGiBHour function
  358. func testSetVolumePricePerStorageGiBHour(t *testing.T, ctx context.Context, pm *PricingModule) {
  359. newPrice := 0.0003
  360. err := pm.SetVolumePricePerStorageGiBHour(ctx, newPrice)
  361. if err != nil {
  362. t.Fatalf("Failed to set volume storage price: %v", err)
  363. }
  364. // Verify the price was set
  365. vp, err := pm.getVolumePricing(ctx)
  366. if err != nil {
  367. t.Fatalf("Failed to get volume pricing: %v", err)
  368. }
  369. prices, err := vp.Prices.GetPricesInCurrency(pm.GetCurrency())
  370. if err != nil {
  371. t.Fatalf("Failed to get prices: %v", err)
  372. }
  373. found := false
  374. for _, price := range prices {
  375. if price.Unit == unit.StorageGiBHour {
  376. found = true
  377. if price.Price != newPrice {
  378. t.Errorf("Expected volume storage price to be %f, got %f", newPrice, price.Price)
  379. }
  380. }
  381. }
  382. if !found {
  383. t.Error("Expected to find volume storage pricing")
  384. }
  385. }
  386. // testNewNodePricingReader tests the NewNodePricingReader function
  387. func testNewNodePricingReader(t *testing.T, ctx context.Context, pm *PricingModule) {
  388. // Test that NewNodePricingReader always produces a reader
  389. rdr, err := pm.NewNodePricingReader(ctx)
  390. if err != nil {
  391. t.Fatalf("Failed to create node pricing reader: %v", err)
  392. }
  393. if rdr == nil {
  394. t.Fatal("Expected reader to be non-nil")
  395. }
  396. // Test that the reader produces precisely one *NodePricing struct
  397. dst := make([]*pricing.NodePricing, 10) // Buffer larger than expected
  398. count := 0
  399. for {
  400. n, err := rdr.Read(ctx, dst)
  401. count += n
  402. // Verify all read items are non-nil
  403. for i := 0; i < n; i++ {
  404. if dst[i] == nil {
  405. t.Error("Expected non-nil NodePricing")
  406. }
  407. }
  408. if err == reader.Done {
  409. break
  410. }
  411. if err != nil {
  412. t.Fatalf("Reader error: %v", err)
  413. }
  414. }
  415. if count != 1 {
  416. t.Errorf("Expected reader to produce exactly 1 NodePricing, got %d", count)
  417. }
  418. // Clean up
  419. if err := rdr.Close(); err != nil {
  420. t.Errorf("Failed to close reader: %v", err)
  421. }
  422. }
  423. // testNewVolumePricingReader tests the NewVolumePricingReader function
  424. func testNewVolumePricingReader(t *testing.T, ctx context.Context, pm *PricingModule) {
  425. // Test that NewVolumePricingReader always produces a reader
  426. rdr, err := pm.NewVolumePricingReader(ctx)
  427. if err != nil {
  428. t.Fatalf("Failed to create volume pricing reader: %v", err)
  429. }
  430. if rdr == nil {
  431. t.Fatal("Expected reader to be non-nil")
  432. }
  433. // Test that the reader produces precisely one *VolumePricing struct
  434. dst := make([]*pricing.VolumePricing, 10) // Buffer larger than expected
  435. count := 0
  436. for {
  437. n, err := rdr.Read(ctx, dst)
  438. count += n
  439. // Verify all read items are non-nil
  440. for i := 0; i < n; i++ {
  441. if dst[i] == nil {
  442. t.Error("Expected non-nil VolumePricing")
  443. }
  444. }
  445. if err == reader.Done {
  446. break
  447. }
  448. if err != nil {
  449. t.Fatalf("Reader error: %v", err)
  450. }
  451. }
  452. if count != 1 {
  453. t.Errorf("Expected reader to produce exactly 1 VolumePricing, got %d", count)
  454. }
  455. // Clean up
  456. if err := rdr.Close(); err != nil {
  457. t.Errorf("Failed to close reader: %v", err)
  458. }
  459. }
  460. func newFileStorage(t *testing.T) storage.Storage {
  461. tempDir, err := os.MkdirTemp("", "pricing-test-*")
  462. if err != nil {
  463. t.Fatalf("Failed to create temp directory: %v", err)
  464. }
  465. defer os.RemoveAll(tempDir)
  466. return storage.NewFileStorage(tempDir)
  467. }