mock_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package pricing
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "testing"
  7. "github.com/opencost/opencost/core/pkg/model/kubemodel"
  8. "github.com/opencost/opencost/core/pkg/model/shared"
  9. "github.com/opencost/opencost/core/pkg/reader"
  10. )
  // TestMockPricingModule drives the MockPricingModule through the
  // PricingSource reader API: it opens node and persistent-volume pricing
  // readers, drains each into an in-memory ingestor, and checks both the
  // count reported by each ingest call and the number of records stored.
  func TestMockPricingModule(t *testing.T) {
  	const (
  		// Number of records each mock reader is expected to yield.
  		expectedNodeRecords   = 39
  		expectedVolumeRecords = 20
  		// Smaller than either expected count, so each reader is drained over
  		// several Read calls of at most bufferSize records.
  		bufferSize = 10
  	)
  	pricingModule, err := NewMockPricingModule()
  	if err != nil {
  		t.Fatalf("unexpected error initializing mock repository: %s", err)
  	}
  	// Declaring source with the interface type makes the compiler check that
  	// the mock module satisfies PricingSource.
  	var source PricingSource = pricingModule
  	// Simple example of a sink for pricing data (will be database tables in reality)
  	ingestor := newMockIngestor(bufferSize)
  	// Test ingestion of mock node reader.
  	nodePricingReader, err := source.NewNodePricingReader(t.Context())
  	if err != nil {
  		// Fatal rather than Errorf: the ingestor defers Close on the reader, so
  		// continuing with a reader returned alongside an error risks a nil
  		// dereference instead of a clean test failure.
  		t.Fatalf("unexpected error initializing node reader: %s", err)
  	}
  	// t.Context() is cancelled when the test ends, unlike context.Background().
  	n, err := ingestor.ingestNodePricing(t.Context(), nodePricingReader)
  	if err != nil {
  		t.Errorf("unexpected error ingesting node pricing: %s", err)
  	}
  	if n != expectedNodeRecords {
  		t.Errorf("expected to ingest %d node pricing records; ingested %d", expectedNodeRecords, n)
  	}
  	if nodePricingCount := ingestor.countNodePricing(); nodePricingCount != expectedNodeRecords {
  		t.Errorf("expected %d node pricing records; received %d", expectedNodeRecords, nodePricingCount)
  	}
  	// Test ingestion of mock persistent volume reader.
  	volumePricingReader, err := source.NewPersistentVolumePricingReader(t.Context())
  	if err != nil {
  		t.Fatalf("unexpected error initializing volume reader: %s", err)
  	}
  	n, err = ingestor.ingestPersistentVolumePricing(t.Context(), volumePricingReader)
  	if err != nil {
  		t.Errorf("unexpected error ingesting volume pricing: %s", err)
  	}
  	if n != expectedVolumeRecords {
  		t.Errorf("expected to ingest %d volume pricing records; ingested %d", expectedVolumeRecords, n)
  	}
  	if volumePricingCount := ingestor.countVolumePricing(); volumePricingCount != expectedVolumeRecords {
  		t.Errorf("expected %d volume pricing records; received %d", expectedVolumeRecords, volumePricingCount)
  	}
  }
  // TestMockGetNodePricing looks up a node by provider, region, instance type
  // and provisioning, checks that the matching entry carries the on-demand
  // price loaded from YAML, and checks that a lookup in an absent region
  // returns an error.
  func TestMockGetNodePricing(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	known := NodePricingProperties{
  		Provider:     shared.ProviderAWS,
  		Region:       "us-east-1",
  		InstanceType: "m5.large",
  		Provisioning: ProvisioningOnDemand,
  	}
  	pricing, err := module.GetNodePricing(ctx, known)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	// Guards against the YAML tag regression: prices must actually load.
  	nodePrice, found := pricing.Prices[ResourceNode]
  	if !found {
  		t.Fatalf("expected node price to be present, prices=%v", pricing.Prices)
  	}
  	if got := nodePrice.Price; got != 0.096 {
  		t.Errorf("expected on-demand price 0.096, got %v", got)
  	}
  	// An absent entry must produce an error, not a zero-valued result.
  	unknown := known
  	unknown.Region = "eu-west-1"
  	if _, err := module.GetNodePricing(ctx, unknown); err == nil {
  		t.Errorf("expected error for unknown region, got nil")
  	}
  }
  // TestMockGetNodePricingProvisioningDiscriminates checks that two lookups
  // differing only in provisioning (on-demand vs spot) resolve to distinct
  // entries with their own prices.
  func TestMockGetNodePricingProvisioningDiscriminates(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	onDemand := NodePricingProperties{
  		Provider:     shared.ProviderAWS,
  		Region:       "us-east-1",
  		InstanceType: "m5.large",
  		Provisioning: ProvisioningOnDemand,
  	}
  	spot := onDemand
  	spot.Provisioning = ProvisioningSpot
  	od, err := module.GetNodePricing(ctx, onDemand)
  	if err != nil {
  		t.Fatalf("unexpected error (on-demand): %v", err)
  	}
  	sp, err := module.GetNodePricing(ctx, spot)
  	if err != nil {
  		t.Fatalf("unexpected error (spot): %v", err)
  	}
  	odPrice := od.Prices[ResourceNode].Price
  	spPrice := sp.Prices[ResourceNode].Price
  	if odPrice == spPrice {
  		t.Errorf("expected on-demand and spot to differ, both = %v", odPrice)
  	}
  	if odPrice != 0.096 {
  		t.Errorf("expected on-demand 0.096, got %v", odPrice)
  	}
  	if spPrice != 0.043 {
  		t.Errorf("expected spot 0.043, got %v", spPrice)
  	}
  }
  // TestMockGetPersistentVolumePricing checks that a gp3 volume lookup returns
  // the expected storage price, and that an io2 lookup (absent from the mock
  // data) returns an error.
  func TestMockGetPersistentVolumePricing(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	gp3 := PersistentVolumePricingProperties{
  		Provider:   shared.ProviderAWS,
  		Region:     "us-east-1",
  		VolumeType: VolumeTypeGP3,
  	}
  	pv, err := module.GetPersistentVolumePricing(ctx, gp3)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	// On a missing key, storage is the zero value, the same thing a fresh
  	// map index would report.
  	storage, ok := pv.Prices[ResourceStorage]
  	if !ok || storage.Price != 0.0001096 {
  		t.Errorf("expected gp3 storage price 0.0001096, got %v (ok=%t)", storage.Price, ok)
  	}
  	io2 := gp3
  	io2.VolumeType = VolumeTypeIO2
  	if _, err := module.GetPersistentVolumePricing(ctx, io2); err == nil {
  		t.Errorf("expected error for unknown volume type, got nil")
  	}
  }
  // TestMockGetClusterPricing checks that an AWS cluster lookup returns the
  // expected cluster price and that an Oracle lookup returns an error.
  func TestMockGetClusterPricing(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	cp, err := module.GetClusterPricing(ctx, ClusterPricingProperties{Provider: shared.ProviderAWS})
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	clusterPrice, ok := cp.Prices[ResourceCluster]
  	if !ok || clusterPrice.Price != 0.10 {
  		t.Errorf("expected cluster price 0.10, got %v (ok=%t)", clusterPrice.Price, ok)
  	}
  	_, err = module.GetClusterPricing(ctx, ClusterPricingProperties{Provider: shared.ProviderOracle})
  	if err == nil {
  		t.Errorf("expected error for unknown provider, got nil")
  	}
  }
  // TestMockGetNetworkPricing checks internet-egress pricing with and without
  // the NAT gateway flag, checks that the flag selects a different price, and
  // checks that an ingress lookup (absent from the mock data) returns an error.
  func TestMockGetNetworkPricing(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	egressInternet := NetworkPricingProperties{
  		Provider:         shared.ProviderAWS,
  		TrafficDirection: kubemodel.TrafficDirectionEgress,
  		TrafficType:      kubemodel.TrafficTypeInternet,
  	}
  	internet, err := module.GetNetworkPricing(ctx, egressInternet)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	internetPrice := internet.Prices[ResourceNetworkTraffic].Price
  	if internetPrice != 0.09 {
  		t.Errorf("expected internet egress price 0.09, got %v", internetPrice)
  	}
  	viaNat := egressInternet
  	viaNat.IsNatGateway = true
  	nat, err := module.GetNetworkPricing(ctx, viaNat)
  	if err != nil {
  		t.Fatalf("unexpected error (nat): %v", err)
  	}
  	natPrice := nat.Prices[ResourceNetworkTraffic].Price
  	if natPrice != 0.045 {
  		t.Errorf("expected NAT gateway price 0.045, got %v", natPrice)
  	}
  	if internetPrice == natPrice {
  		t.Errorf("expected NAT gateway flag to discriminate pricing")
  	}
  	// Derived from the non-NAT egress query, so IsNatGateway stays false.
  	ingress := egressInternet
  	ingress.TrafficDirection = kubemodel.TrafficDirectionIngress
  	if _, err := module.GetNetworkPricing(ctx, ingress); err == nil {
  		t.Errorf("expected error for unknown traffic direction, got nil")
  	}
  }
  // TestMockGetServicePricing checks that an AWS us-east-1 service lookup
  // returns the expected price and that us-west-2 returns an error.
  func TestMockGetServicePricing(t *testing.T) {
  	module := newMock(t)
  	ctx := t.Context()
  	usEast1 := ServicePricingProperties{
  		Provider: shared.ProviderAWS,
  		Region:   "us-east-1",
  	}
  	svc, err := module.GetServicePricing(ctx, usEast1)
  	if err != nil {
  		t.Fatalf("unexpected error: %v", err)
  	}
  	servicePrice, ok := svc.Prices[ResourceService]
  	if !ok || servicePrice.Price != 0.025 {
  		t.Errorf("expected service price 0.025, got %v (ok=%t)", servicePrice.Price, ok)
  	}
  	usWest2 := usEast1
  	usWest2.Region = "us-west-2"
  	if _, err := module.GetServicePricing(ctx, usWest2); err == nil {
  		t.Errorf("expected error for unknown region, got nil")
  	}
  }
  // newMock returns a freshly constructed MockPricingModule. It marks itself as
  // a test helper and aborts the calling test if construction fails.
  func newMock(t *testing.T) *MockPricingModule {
  	t.Helper()
  	mpm, err := NewMockPricingModule()
  	if err == nil {
  		return mpm
  	}
  	t.Fatalf("unexpected error initializing mock pricing module: %v", err)
  	return nil // unreachable: Fatalf stops the test goroutine
  }
  221. type mockPricingIngestor struct {
  222. bufferSize int
  223. clusterPricing []*ClusterPricing
  224. networkPricing []*NetworkPricing
  225. nodePricing []*NodePricing
  226. persistentVolumePricing []*PersistentVolumePricing
  227. servicePricing []*ServicePricing
  228. }
  229. func newMockIngestor(bufferSize int) *mockPricingIngestor {
  230. if bufferSize == 0 {
  231. bufferSize = 100
  232. }
  233. return &mockPricingIngestor{
  234. bufferSize: bufferSize,
  235. clusterPricing: []*ClusterPricing{},
  236. networkPricing: []*NetworkPricing{},
  237. nodePricing: []*NodePricing{},
  238. persistentVolumePricing: []*PersistentVolumePricing{},
  239. servicePricing: []*ServicePricing{},
  240. }
  241. }
  242. func (ing *mockPricingIngestor) countNodePricing() int {
  243. return len(ing.nodePricing)
  244. }
  // ingestNodePricing drains pricingReader into ing.nodePricing, reading at most
  // ing.bufferSize records per Read call, and always closes the reader.
  //
  // It returns the number of records this call appended. Reaching reader.Done
  // returns that count with a nil error. Any other read error is wrapped (%w)
  // and returned together with the count of records appended before it,
  // including any delivered in the same Read call as the error.
  func (ing *mockPricingIngestor) ingestNodePricing(ctx context.Context, pricingReader reader.Reader[*NodePricing]) (int, error) {
  	defer pricingReader.Close()
  	buf := make([]*NodePricing, ing.bufferSize)
  	total := 0
  	for {
  		n, err := pricingReader.Read(ctx, buf)
  		// Store and count a batch before looking at err. A reader may deliver
  		// its final records in the same call that reports reader.Done, and
  		// those records are kept, so they must be reflected in the count.
  		if n > 0 {
  			ing.nodePricing = append(ing.nodePricing, buf[:n]...)
  			total += n
  		}
  		if errors.Is(err, reader.Done) {
  			return total, nil
  		}
  		if err != nil {
  			return total, fmt.Errorf("unexpected error reading node pricing: %w", err)
  		}
  	}
  }
  264. func (ing *mockPricingIngestor) countVolumePricing() int {
  265. return len(ing.persistentVolumePricing)
  266. }
  // ingestPersistentVolumePricing drains pricingReader into
  // ing.persistentVolumePricing, reading at most ing.bufferSize records per Read
  // call, and always closes the reader.
  //
  // It returns the number of records this call appended. Reaching reader.Done
  // returns that count with a nil error. Any other read error is wrapped (%w)
  // and returned together with the count of records appended before it,
  // including any delivered in the same Read call as the error.
  func (ing *mockPricingIngestor) ingestPersistentVolumePricing(ctx context.Context, pricingReader reader.Reader[*PersistentVolumePricing]) (int, error) {
  	defer pricingReader.Close()
  	buf := make([]*PersistentVolumePricing, ing.bufferSize)
  	total := 0
  	for {
  		n, err := pricingReader.Read(ctx, buf)
  		// Store and count a batch before looking at err. A reader may deliver
  		// its final records in the same call that reports reader.Done, and
  		// those records are kept, so they must be reflected in the count.
  		if n > 0 {
  			ing.persistentVolumePricing = append(ing.persistentVolumePricing, buf[:n]...)
  			total += n
  		}
  		if errors.Is(err, reader.Done) {
  			return total, nil
  		}
  		if err != nil {
  			return total, fmt.Errorf("unexpected error reading volume pricing: %w", err)
  		}
  	}
  }