mock.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package pricing
  2. import (
  3. "context"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "path/filepath"
  9. "strings"
  10. "github.com/opencost/opencost/core/pkg/reader"
  11. "gopkg.in/yaml.v3"
  12. )
  13. // MockPricingModule must satisfy the PricingModule interface
  14. var _ PricingModule = (*MockPricingModule)(nil)
  15. type MockPricingModule struct {
  16. ClusterPricing []*ClusterPricing
  17. NetworkPricing []*NetworkPricing
  18. NodePricing []*NodePricing
  19. PersistentVolumePricing []*PersistentVolumePricing
  20. ServicePricing []*ServicePricing
  21. }
  22. func NewMockPricingModule() (*MockPricingModule, error) {
  23. mpm := &MockPricingModule{
  24. ClusterPricing: []*ClusterPricing{},
  25. NetworkPricing: []*NetworkPricing{},
  26. NodePricing: []*NodePricing{},
  27. PersistentVolumePricing: []*PersistentVolumePricing{},
  28. ServicePricing: []*ServicePricing{},
  29. }
  30. // Default
  31. err := mpm.loadTestFile("default.yaml")
  32. if err != nil {
  33. return nil, fmt.Errorf("error loading test default pricing: %w", err)
  34. }
  35. // AWS
  36. err = mpm.loadTestFile("aws.yaml")
  37. if err != nil {
  38. return nil, fmt.Errorf("error loading test AWS pricing: %w", err)
  39. }
  40. // Azure
  41. err = mpm.loadTestFile("azure.yaml")
  42. if err != nil {
  43. return nil, fmt.Errorf("error loading test Azure pricing: %w", err)
  44. }
  45. // GCP
  46. err = mpm.loadTestFile("gcp.yaml")
  47. if err != nil {
  48. return nil, fmt.Errorf("error loading test GCP pricing: %w", err)
  49. }
  50. return mpm, nil
  51. }
  52. func (mpm *MockPricingModule) GetClusterPricing(ctx context.Context, props ClusterPricingProperties) (*ClusterPricing, error) {
  53. if err := ctx.Err(); err != nil {
  54. return nil, err
  55. }
  56. // Search through the mock data for a matching cluster pricing entry
  57. for _, cp := range mpm.ClusterPricing {
  58. if cp.Properties.Provider == props.Provider {
  59. return cp, nil
  60. }
  61. }
  62. return nil, fmt.Errorf("cluster pricing not found for provider=%s", props.Provider)
  63. }
  64. func (mpm *MockPricingModule) NewClusterPricingReader(ctx context.Context) (reader.Reader[*ClusterPricing], error) {
  65. return reader.NewSliceReader(mpm.ClusterPricing), nil
  66. }
  67. func (mpm *MockPricingModule) GetNetworkPricing(ctx context.Context, props NetworkPricingProperties) (*NetworkPricing, error) {
  68. if err := ctx.Err(); err != nil {
  69. return nil, err
  70. }
  71. // Search through the mock data for a matching network pricing entry
  72. for _, np := range mpm.NetworkPricing {
  73. if np.Properties.Provider == props.Provider &&
  74. np.Properties.TrafficDirection == props.TrafficDirection &&
  75. np.Properties.TrafficType == props.TrafficType &&
  76. np.Properties.IsNatGateway == props.IsNatGateway {
  77. return np, nil
  78. }
  79. }
  80. return nil, fmt.Errorf("network pricing not found for provider=%s, trafficDirection=%s, trafficType=%s, isNatGateway=%t",
  81. props.Provider, props.TrafficDirection, props.TrafficType, props.IsNatGateway)
  82. }
  83. func (mpm *MockPricingModule) NewNetworkPricingReader(ctx context.Context) (reader.Reader[*NetworkPricing], error) {
  84. return reader.NewSliceReader(mpm.NetworkPricing), nil
  85. }
  86. func (mpm *MockPricingModule) GetNodePricing(ctx context.Context, props NodePricingProperties) (*NodePricing, error) {
  87. if err := ctx.Err(); err != nil {
  88. return nil, err
  89. }
  90. // Search through the mock data for a matching node pricing entry
  91. for _, np := range mpm.NodePricing {
  92. if np.Properties.Provider == props.Provider &&
  93. np.Properties.Region == props.Region &&
  94. np.Properties.InstanceType == props.InstanceType &&
  95. np.Properties.Provisioning == props.Provisioning &&
  96. np.Properties.Commitment == props.Commitment {
  97. return np, nil
  98. }
  99. }
  100. return nil, fmt.Errorf("node pricing not found for provider=%s, region=%s, instanceType=%s, provisioning=%s, commitment=%s",
  101. props.Provider, props.Region, props.InstanceType, props.Provisioning, props.Commitment)
  102. }
  103. func (mpm *MockPricingModule) NewNodePricingReader(ctx context.Context) (reader.Reader[*NodePricing], error) {
  104. return reader.NewSliceReader(mpm.NodePricing), nil
  105. }
  106. func (mpm *MockPricingModule) GetPersistentVolumePricing(ctx context.Context, props PersistentVolumePricingProperties) (*PersistentVolumePricing, error) {
  107. if err := ctx.Err(); err != nil {
  108. return nil, err
  109. }
  110. // Search through the mock data for a matching volume pricing entry
  111. for _, vp := range mpm.PersistentVolumePricing {
  112. if vp.Properties.Provider == props.Provider &&
  113. vp.Properties.Region == props.Region &&
  114. vp.Properties.VolumeType == props.VolumeType {
  115. return vp, nil
  116. }
  117. }
  118. return nil, fmt.Errorf("volume pricing not found for provider=%s, region=%s, volumeType=%s", props.Provider, props.Region, props.VolumeType)
  119. }
  120. func (mpm *MockPricingModule) NewPersistentVolumePricingReader(ctx context.Context) (reader.Reader[*PersistentVolumePricing], error) {
  121. return reader.NewSliceReader(mpm.PersistentVolumePricing), nil
  122. }
  123. func (mpm *MockPricingModule) GetServicePricing(ctx context.Context, props ServicePricingProperties) (*ServicePricing, error) {
  124. if err := ctx.Err(); err != nil {
  125. return nil, err
  126. }
  127. // Search through the mock data for a matching service pricing entry
  128. for _, sp := range mpm.ServicePricing {
  129. if sp.Properties.Provider == props.Provider &&
  130. sp.Properties.Region == props.Region {
  131. return sp, nil
  132. }
  133. }
  134. return nil, fmt.Errorf("service pricing not found for provider=%s, region=%s", props.Provider, props.Region)
  135. }
  136. func (mpm *MockPricingModule) NewServicePricingReader(ctx context.Context) (reader.Reader[*ServicePricing], error) {
  137. return reader.NewSliceReader(mpm.ServicePricing), nil
  138. }
  139. func (mpm *MockPricingModule) GetPricingSet(ctx context.Context) (*PricingSet, error) {
  140. ps := &PricingSet{
  141. ClusterPricing: mpm.ClusterPricing,
  142. NetworkPricing: mpm.NetworkPricing,
  143. NodePricing: mpm.NodePricing,
  144. PersistentVolumePricing: mpm.PersistentVolumePricing,
  145. ServicePricing: mpm.ServicePricing,
  146. }
  147. return ps, nil
  148. }
  149. func (mpm *MockPricingModule) SourceKind() string {
  150. return "test"
  151. }
  152. func (mpm *MockPricingModule) SourceName() string {
  153. return "mock"
  154. }
  155. func (mpm *MockPricingModule) Checksum(ctx context.Context) (string, error) {
  156. ps, err := mpm.GetPricingSet(ctx)
  157. if err != nil {
  158. return "", fmt.Errorf("getting pricing set: %w", err)
  159. }
  160. return ps.Checksum()
  161. }
  162. //go:embed test/*
  163. var pricingTestFS embed.FS
  164. func (mpm *MockPricingModule) loadTestFile(filename string) error {
  165. path := filepath.Join("test", filename)
  166. bs, err := pricingTestFS.ReadFile(path)
  167. if err != nil {
  168. return fmt.Errorf("failed to read embedded pricing file: %w", err)
  169. }
  170. var set *PricingSet
  171. // Detect file format based on extension
  172. ext := strings.ToLower(filepath.Ext(filename))
  173. switch ext {
  174. case ".json":
  175. err = json.Unmarshal(bs, &set)
  176. if err != nil {
  177. return fmt.Errorf("failed to parse json: %w", err)
  178. }
  179. case ".yaml", ".yml":
  180. err = yaml.Unmarshal(bs, &set)
  181. if err != nil {
  182. return fmt.Errorf("failed to parse yaml: %w", err)
  183. }
  184. default:
  185. return fmt.Errorf("unsupported file format: %s (expected .json, .yaml, or .yml)", ext)
  186. }
  187. if set == nil {
  188. return errors.New("nil set")
  189. }
  190. mpm.ClusterPricing = append(mpm.ClusterPricing, set.ClusterPricing...)
  191. mpm.NetworkPricing = append(mpm.NetworkPricing, set.NetworkPricing...)
  192. mpm.NodePricing = append(mpm.NodePricing, set.NodePricing...)
  193. mpm.PersistentVolumePricing = append(mpm.PersistentVolumePricing, set.PersistentVolumePricing...)
  194. mpm.ServicePricing = append(mpm.ServicePricing, set.ServicePricing...)
  195. return nil
  196. }