mock.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. package pricing
  2. import (
  3. "context"
  4. "embed"
  5. "encoding/json"
  6. "fmt"
  7. "path/filepath"
  8. "strings"
  9. "github.com/opencost/opencost/core/pkg/reader"
  10. "gopkg.in/yaml.v3"
  11. )
  12. type MockPricingRepository struct {
  13. NodePricing []*NodePricing
  14. VolumePricing []*VolumePricing
  15. }
  16. func NewMockPricingRepository() (*MockPricingRepository, error) {
  17. repo := &MockPricingRepository{
  18. NodePricing: []*NodePricing{},
  19. VolumePricing: []*VolumePricing{},
  20. }
  21. // Default
  22. defaultPricingSet, err := loadTestFile("default.yaml")
  23. if err != nil {
  24. return nil, fmt.Errorf("error loading test default pricing: %w", err)
  25. }
  26. repo.NodePricing = append(repo.NodePricing, defaultPricingSet.Nodes...)
  27. repo.VolumePricing = append(repo.VolumePricing, defaultPricingSet.Volumes...)
  28. // AWS
  29. awsPricingSet, err := loadTestFile("aws.yaml")
  30. if err != nil {
  31. return nil, fmt.Errorf("error loading test AWS pricing: %w", err)
  32. }
  33. repo.NodePricing = append(repo.NodePricing, awsPricingSet.Nodes...)
  34. repo.VolumePricing = append(repo.VolumePricing, awsPricingSet.Volumes...)
  35. // Azure
  36. azurePricingSet, err := loadTestFile("azure.yaml")
  37. if err != nil {
  38. return nil, fmt.Errorf("error loading test Azure pricing: %w", err)
  39. }
  40. repo.NodePricing = append(repo.NodePricing, azurePricingSet.Nodes...)
  41. repo.VolumePricing = append(repo.VolumePricing, azurePricingSet.Volumes...)
  42. // GCP
  43. gcpPricingSet, err := loadTestFile("gcp.yaml")
  44. if err != nil {
  45. return nil, fmt.Errorf("error loading test GCP pricing: %w", err)
  46. }
  47. repo.NodePricing = append(repo.NodePricing, gcpPricingSet.Nodes...)
  48. repo.VolumePricing = append(repo.VolumePricing, gcpPricingSet.Volumes...)
  49. return repo, nil
  50. }
  51. func (repo *MockPricingRepository) NewNodePricingReader(ctx context.Context) (reader.Reader[*NodePricing], error) {
  52. return reader.NewSliceReader(repo.NodePricing), nil
  53. }
  54. func (repo *MockPricingRepository) GetNodePricing(provider Provider, instanceType string, region string) (*NodePricing, error) {
  55. // Search through the mock data for a matching node pricing entry
  56. for _, np := range repo.NodePricing {
  57. if np.Properties.Provider == provider &&
  58. np.Properties.InstanceType == instanceType &&
  59. np.Properties.Region == region {
  60. return np, nil
  61. }
  62. }
  63. return nil, fmt.Errorf("node pricing not found for provider=%s, instanceType=%s, region=%s", provider, instanceType, region)
  64. }
  65. func (repo *MockPricingRepository) NewVolumePricingReader(ctx context.Context) (reader.Reader[*VolumePricing], error) {
  66. return reader.NewSliceReader(repo.VolumePricing), nil
  67. }
  68. func (repo *MockPricingRepository) GetVolumePricing(props VolumePricingProperties) (*VolumePricing, error) {
  69. // Search through the mock data for a matching volume pricing entry
  70. for _, vp := range repo.VolumePricing {
  71. if vp.Properties.Provider == props.Provider &&
  72. vp.Properties.Region == props.Region &&
  73. vp.Properties.VolumeType == props.VolumeType {
  74. return vp, nil
  75. }
  76. }
  77. return nil, fmt.Errorf("volume pricing not found for provider=%s, region=%s, volumeType=%s", props.Provider, props.Region, props.VolumeType)
  78. }
  79. //go:embed test/*
  80. var pricingTestFS embed.FS
  81. func loadTestFile(filename string) (*PricingSet, error) {
  82. path := filepath.Join("test", filename)
  83. bs, err := pricingTestFS.ReadFile(path)
  84. if err != nil {
  85. panic(fmt.Errorf("failed to read embedded pricing file: %w", err))
  86. }
  87. var set *PricingSet
  88. // Detect file format based on extension
  89. ext := strings.ToLower(filepath.Ext(filename))
  90. switch ext {
  91. case ".json":
  92. err = json.Unmarshal(bs, &set)
  93. if err != nil {
  94. return nil, fmt.Errorf("failed to parse json: %w", err)
  95. }
  96. case ".yaml", ".yml":
  97. err = yaml.Unmarshal(bs, &set)
  98. if err != nil {
  99. return nil, fmt.Errorf("failed to parse yaml: %w", err)
  100. }
  101. default:
  102. return nil, fmt.Errorf("unsupported file format: %s (expected .json, .yaml, or .yml)", ext)
  103. }
  104. return set, nil
  105. }