Browse Source

fix: Address review feedback on spot price history implementation

- Use %v instead of %s for error formatting in log.Errorf
- Add defer cleanup for refreshRunning to prevent deadlocks on panic
- Reduce log level from Info to Debug for cache miss fetches
- Add concurrency test verifying only one fetch per key under contention

Signed-off-by: Warwick Peatey <warwick@automatic.systems>
Assisted-by: Claude Code
Claude 4 tuần trước cách đây
mục cha
commit
e170fc4f30

+ 1 - 1
pkg/cloud/aws/provider.go

@@ -1043,7 +1043,7 @@ func (aws *AWS) DownloadPricingData() error {
 		aws.SpotPriceHistoryError = nil
 		aws.SpotPriceHistoryCache, aws.SpotPriceHistoryError = aws.initializeSpotPriceHistoryCache()
 		if aws.SpotPriceHistoryError != nil {
-			log.Errorf("Failed to initialize spot price history manager: %s", aws.SpotPriceHistoryError)
+			log.Errorf("Failed to initialize spot price history manager: %v", aws.SpotPriceHistoryError)
 		}
 	}
 

+ 9 - 3
pkg/cloud/aws/spotpricehistory.go

@@ -89,6 +89,14 @@ func (sph *SpotPriceHistoryCache) GetSpotPrice(region, instanceType, availabilit
 	sph.refreshRunning[key] = true
 	sph.mutex.Unlock()
 
+	// Ensure refreshRunning is always cleared, even if the fetcher panics.
+	defer func() {
+		sph.mutex.Lock()
+		delete(sph.refreshRunning, key)
+		sph.refreshCond.Broadcast()
+		sph.mutex.Unlock()
+	}()
+
 	// Fetch the entry
 	entry, err := sph.fetcher.FetchSpotPrice(key)
 	if err != nil {
@@ -102,8 +110,6 @@ func (sph *SpotPriceHistoryCache) GetSpotPrice(region, instanceType, availabilit
 	// Store it into the cache
 	sph.mutex.Lock()
 	sph.cache[key] = entry
-	delete(sph.refreshRunning, key)
-	sph.refreshCond.Broadcast()
 	sph.mutex.Unlock()
 	return entry, entry.Error
 }
@@ -144,7 +150,7 @@ func (a *AWSSpotPriceHistoryFetcher) getEC2Client(region string) *ec2.Client {
 }
 
 func (a *AWSSpotPriceHistoryFetcher) FetchSpotPrice(key SpotPriceHistoryKey) (*SpotPriceHistoryEntry, error) {
-	log.Infof("Retrieving spot price history for %s", key)
+	log.Debugf("Retrieving spot price history for %s", key)
 	client := a.getEC2Client(key.Region)
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()

+ 40 - 0
pkg/cloud/aws/spotpricehistory_test.go

@@ -2,6 +2,8 @@ package aws
 
 import (
 	"errors"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
@@ -77,6 +79,44 @@ func TestSpotPriceHistoryCache_GetSpotPrice_CacheMiss(t *testing.T) {
 	}
 }
 
+func TestSpotPriceHistoryCache_GetSpotPrice_ConcurrentSameKey(t *testing.T) {
+	var fetchCount atomic.Int32
+	mockFetcher := &mockSpotPriceHistoryFetcher{
+		fetchFunc: func(key SpotPriceHistoryKey) (*SpotPriceHistoryEntry, error) {
+			fetchCount.Add(1)
+			// Simulate slow API call to increase chance of concurrent access
+			time.Sleep(50 * time.Millisecond)
+			return &SpotPriceHistoryEntry{
+				SpotPrice:   0.07,
+				Timestamp:   time.Now(),
+				RetrievedAt: time.Now(),
+			}, nil
+		},
+	}
+	cache := NewSpotPriceHistoryCache(mockFetcher)
+
+	const goroutines = 10
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+	for i := 0; i < goroutines; i++ {
+		go func() {
+			defer wg.Done()
+			entry, err := cache.GetSpotPrice("us-west-2", "m5.large", "us-west-2a")
+			if err != nil {
+				t.Errorf("Expected no error, got %v", err)
+			}
+			if entry.SpotPrice != 0.07 {
+				t.Errorf("Expected spot price 0.07, got %f", entry.SpotPrice)
+			}
+		}()
+	}
+	wg.Wait()
+
+	if count := fetchCount.Load(); count != 1 {
+		t.Errorf("Expected exactly 1 fetch call, got %d", count)
+	}
+}
+
 func TestSpotPriceHistoryCache_GetSpotPrice_StaleEntry(t *testing.T) {
 	fetchCalled := false
 	mockFetcher := &mockSpotPriceHistoryFetcher{