Преглед изворни кода

fix(cloud): add timeouts to pricing HTTP clients (#3835)

Signed-off-by: Tushar Verma <tusharmyself06@gmail.com>
Tushar-Verma пре 7 часа
родитељ
комит
5a8bb3d397

+ 5 - 1
pkg/cloud/aws/provider.go

@@ -17,6 +17,7 @@ import (
 	"time"
 
 	"github.com/aws/smithy-go"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 	"github.com/opencost/opencost/pkg/cloud/models"
 	"github.com/opencost/opencost/pkg/cloud/utils"
 
@@ -858,7 +859,10 @@ func (aws *AWS) getRegionPricing(nodeList []*clustercache.Node) (*http.Response,
 	}
 
 	log.Infof("starting download of \"%s\", which is quite large ...", pricingURL)
-	resp, err := http.Get(pricingURL)
+	// This file is large and can take a while to stream, so the streaming client
+	// bounds connect/TLS/response-header time but not the total body read - enough
+	// to bail on a hung endpoint without truncating a legitimate slow download.
+	resp, err := httputil.StreamingGet(context.Background(), pricingURL)
 	if err != nil {
 		log.Errorf("Bogus fetch of \"%s\": %v", pricingURL, err)
 		return nil, pricingURL, err

+ 5 - 1
pkg/cloud/azure/pricesheetdownloader.go

@@ -18,6 +18,7 @@ import (
 	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
 
 	"github.com/opencost/opencost/core/pkg/log"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 )
 
 type PriceSheetDownloader struct {
@@ -79,7 +80,10 @@ func (d PriceSheetDownloader) saveData(ctx context.Context, url, tempName string
 		return nil, fmt.Errorf("creating %s temp file: %w", tempName, err)
 	}
 
-	resp, err := http.Get(url)
+	// The price sheet can be large, so the streaming client bounds connect/TLS/
+	// response-header time but not the body read, avoiding truncation of a slow
+	// download. Pass the caller's context so the download is cancelable.
+	resp, err := httputil.StreamingGet(ctx, url)
 	if err != nil {
 		return nil, fmt.Errorf("downloading: %w", err)
 	}

+ 5 - 1
pkg/cloud/azure/provider.go

@@ -29,6 +29,7 @@ import (
 	"github.com/opencost/opencost/core/pkg/util/fileutil"
 	"github.com/opencost/opencost/core/pkg/util/json"
 	"github.com/opencost/opencost/core/pkg/util/timeutil"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 	"github.com/opencost/opencost/pkg/cloud/models"
 	"github.com/opencost/opencost/pkg/cloud/utils"
 	"github.com/opencost/opencost/pkg/env"
@@ -302,7 +303,10 @@ func getRetailPrice(region string, skuName string, currencyCode string, spot boo
 	pricingURL := buildAzureRetailPricesURL(region, skuName, currencyCode)
 	log.Infof("starting download retail price payload from \"%s\"", pricingURL)
 
-	resp, err := http.Get(pricingURL)
+	// Single SKU lookup returns a small payload, so the shared bounded client
+	// keeps a hung endpoint from blocking pricing without risking truncation.
+	client := httputil.BoundedClient()
+	resp, err := client.Get(pricingURL)
 	if err != nil {
 		return "", fmt.Errorf("failed to fetch retail price with URL \"%s\": %v", pricingURL, err)
 	}

+ 6 - 1
pkg/cloud/gcp/provider.go

@@ -17,6 +17,7 @@ import (
 
 	coreenv "github.com/opencost/opencost/core/pkg/env"
 	"github.com/opencost/opencost/pkg/cloud/aws"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 	"github.com/opencost/opencost/pkg/cloud/models"
 	"github.com/opencost/opencost/pkg/cloud/utils"
 
@@ -979,7 +980,9 @@ func (gcp *GCP) getBillingAPIClientAndURL(apiKey, currencyCode string) (*http.Cl
 	url := gcp.buildBillingAPIURL(apiKey, currencyCode)
 
 	if apiKey != "" {
-		return http.DefaultClient, url.String(), nil
+		// Shared client carries a request timeout so a hung billing endpoint
+		// can't block the pricing refresh.
+		return httputil.BoundedClient(), url.String(), nil
 	}
 
 	googleHttpClient, err := google.DefaultClient(context.TODO(), GCPCloudOAuthScope)
@@ -987,6 +990,8 @@ func (gcp *GCP) getBillingAPIClientAndURL(apiKey, currencyCode string) (*http.Cl
 		log.Errorf("GCP Billing API: Workload Identity detected but failed to create authenticated client: %v", err)
 		return nil, "", err
 	}
+	// google.DefaultClient has no timeout by default; bound it to match the keyed path.
+	googleHttpClient.Timeout = httputil.PricingTimeout
 
 	return googleHttpClient, url.String(), nil
 }

+ 3 - 2
pkg/cloud/gcp/provider_test.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
-	"net/http"
 	"net/url"
 	"os"
 	"reflect"
@@ -14,6 +13,7 @@ import (
 
 	"github.com/google/martian/log"
 	"github.com/opencost/opencost/core/pkg/clustercache"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 	"github.com/opencost/opencost/pkg/cloud/models"
 	"github.com/opencost/opencost/pkg/config"
 	"github.com/stretchr/testify/assert"
@@ -758,7 +758,8 @@ func TestGCP_getBillingAPIClientAndURL(t *testing.T) {
 	client, rawURL, err := gcp.getBillingAPIClientAndURL("test-key", "USD")
 
 	assert.NoError(t, err)
-	assert.Equal(t, http.DefaultClient, client)
+	assert.NotNil(t, client)
+	assert.Equal(t, httputil.PricingTimeout, client.Timeout)
 
 	parsedURL, err := url.Parse(rawURL)
 	assert.NoError(t, err)

+ 78 - 0
pkg/cloud/httputil/httputil.go

@@ -0,0 +1,78 @@
+// Package httputil provides shared HTTP clients for cloud pricing ingestion.
+//
+// The default net/http client has no timeout, so a hung or unreachable provider
+// pricing endpoint can block pricing refresh indefinitely. These helpers return
+// clients with sensible timeouts. They are shared and reused so a new connection
+// pool is not created on every pricing fetch.
+package httputil
+
+import (
+	"context"
+	"net"
+	"net/http"
+	"time"
+)
+
+// PricingTimeout is applied to pricing HTTP requests. For the bounded client it
+// caps the whole request; for the streaming client it caps the wait for
+// response headers only (not the body read).
+const PricingTimeout = 30 * time.Second
+
+var (
+	boundedClient   = &http.Client{Timeout: PricingTimeout}
+	streamingClient = newStreamingClient(http.DefaultTransport)
+)
+
+// BoundedClient returns a shared http.Client with a total request timeout,
+// suitable for small or bounded pricing API responses.
+func BoundedClient() *http.Client {
+	return boundedClient
+}
+
+// StreamingClient returns a shared http.Client for large pricing downloads (for
+// example the AWS pricing file or the Azure price sheet). It bounds the connect,
+// TLS handshake, and response-header wait, but not the total body read, so a
+// legitimately large or slow download is not truncated while a hung endpoint is
+// still abandoned.
+func StreamingClient() *http.Client {
+	return streamingClient
+}
+
+// StreamingGet issues a GET for a large download using the streaming client and
+// the caller's context, so the request is cancelable. It centralizes the
+// context-aware request construction shared by the large-download paths (the AWS
+// pricing file and the Azure price sheet). Callers without a context of their
+// own can pass context.Background().
+func StreamingGet(ctx context.Context, url string) (*http.Response, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	return StreamingClient().Do(req)
+}
+
+// newStreamingClient builds the streaming client from a base RoundTripper. It
+// takes the base transport as a parameter so the fallback path can be exercised
+// by tests; production passes http.DefaultTransport.
+func newStreamingClient(base http.RoundTripper) *http.Client {
+	// Clone the base transport when possible so we keep its dial and TLS
+	// handshake timeouts. Guard the type assertion: http.DefaultTransport is
+	// declared as a RoundTripper and can be replaced (e.g. in tests), in which
+	// case we fall back to a fresh transport rather than panicking.
+	transport, ok := base.(*http.Transport)
+	if ok {
+		transport = transport.Clone()
+	} else {
+		// base is not a *http.Transport, so we can't clone its timeouts. Build a
+		// fresh transport that still bounds dial and TLS handshake time, matching
+		// the guarantee in StreamingClient's doc.
+		transport = &http.Transport{
+			Proxy:                 http.ProxyFromEnvironment,
+			DialContext:           (&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}).DialContext,
+			TLSHandshakeTimeout:   10 * time.Second,
+			ExpectContinueTimeout: 1 * time.Second,
+		}
+	}
+	transport.ResponseHeaderTimeout = PricingTimeout
+	return &http.Client{Transport: transport}
+}

+ 130 - 0
pkg/cloud/httputil/httputil_test.go

@@ -0,0 +1,130 @@
+package httputil
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestBoundedClientHasTotalTimeout(t *testing.T) {
+	c := BoundedClient()
+	if c.Timeout != PricingTimeout {
+		t.Fatalf("expected total timeout %v, got %v", PricingTimeout, c.Timeout)
+	}
+}
+
+func TestBoundedClientIsShared(t *testing.T) {
+	if BoundedClient() != BoundedClient() {
+		t.Fatal("expected BoundedClient to return a shared instance")
+	}
+}
+
+func TestStreamingClientHasNoTotalTimeout(t *testing.T) {
+	c := StreamingClient()
+	if c.Timeout != 0 {
+		t.Fatalf("streaming client must not set a total timeout, got %v", c.Timeout)
+	}
+}
+
+func TestStreamingClientHasResponseHeaderTimeout(t *testing.T) {
+	c := StreamingClient()
+	tr, ok := c.Transport.(*http.Transport)
+	if !ok {
+		t.Fatalf("expected *http.Transport, got %T", c.Transport)
+	}
+	if tr.ResponseHeaderTimeout != PricingTimeout {
+		t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
+	}
+}
+
+func TestStreamingClientIsShared(t *testing.T) {
+	if StreamingClient() != StreamingClient() {
+		t.Fatal("expected StreamingClient to return a shared instance")
+	}
+}
+
+type roundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+	return f(r)
+}
+
+// A base transport that is not a *http.Transport must not panic and must still
+// yield a usable client with the response-header timeout applied.
+func TestNewStreamingClientFallsBackWhenNotTransport(t *testing.T) {
+	// Honor the RoundTripper contract (non-nil response when error is nil), even
+	// though this base is only used to exercise the fallback and never round-trips.
+	base := roundTripperFunc(func(*http.Request) (*http.Response, error) {
+		return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
+	})
+	c := newStreamingClient(base)
+	tr, ok := c.Transport.(*http.Transport)
+	if !ok {
+		t.Fatalf("expected fallback *http.Transport, got %T", c.Transport)
+	}
+	if tr.ResponseHeaderTimeout != PricingTimeout {
+		t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
+	}
+	// The fallback transport must still bound TLS handshake time.
+	if tr.TLSHandshakeTimeout == 0 {
+		t.Fatal("expected fallback transport to set a TLS handshake timeout")
+	}
+}
+
+func TestStreamingGetReturnsBody(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		_, _ = w.Write([]byte("price-data"))
+	}))
+	defer srv.Close()
+
+	resp, err := StreamingGet(context.Background(), srv.URL)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("reading body: %v", err)
+	}
+	if string(body) != "price-data" {
+		t.Fatalf("expected body %q, got %q", "price-data", string(body))
+	}
+}
+
+func TestStreamingGetHonorsCanceledContext(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		_, _ = w.Write([]byte("ok"))
+	}))
+	defer srv.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel before the request runs
+
+	if _, err := StreamingGet(ctx, srv.URL); err == nil {
+		t.Fatal("expected an error from a canceled context, got nil")
+	}
+}
+
+func TestStreamingGetRejectsBadURL(t *testing.T) {
+	if _, err := StreamingGet(context.Background(), "://not-a-url"); err == nil {
+		t.Fatal("expected an error building the request, got nil")
+	}
+}
+
+// Cloning must not mutate the shared default transport.
+func TestNewStreamingClientClonesWithoutMutatingDefault(t *testing.T) {
+	def, ok := http.DefaultTransport.(*http.Transport)
+	if !ok {
+		t.Skip("default transport is not *http.Transport in this environment")
+	}
+	c := newStreamingClient(def)
+	tr := c.Transport.(*http.Transport)
+	if tr == def {
+		t.Fatal("expected a cloned transport, got the shared default")
+	}
+	if def.ResponseHeaderTimeout == PricingTimeout {
+		t.Fatal("mutated the shared default transport's ResponseHeaderTimeout")
+	}
+}

+ 5 - 2
pkg/cloud/oracle/ratecard.go

@@ -8,6 +8,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 	"github.com/opencost/opencost/pkg/cloud/models"
 )
 
@@ -52,8 +53,10 @@ func NewRateCardStore(url, currencyCode string) *RateCardStore {
 	return &RateCardStore{
 		url:          url,
 		currencyCode: currencyCode,
-		client:       &http.Client{},
-		prices:       map[string]Price{},
+		// Zero-value http.Client has no timeout; use the shared bounded client so
+		// a stalled rate-card endpoint can't hang ingestion.
+		client: httputil.BoundedClient(),
+		prices: map[string]Price{},
 	}
 }
 

+ 4 - 1
pkg/cloud/otc/pricingapi.go

@@ -7,9 +7,12 @@ import (
 	"net/http"
 
 	"github.com/opencost/opencost/core/pkg/log"
+	"github.com/opencost/opencost/pkg/cloud/httputil"
 )
 
-var otcHTTPClient = http.DefaultClient
+// http.DefaultClient has no timeout, so a hung pricing endpoint would block the
+// paginated fetch loop forever. Use the shared bounded client instead.
+var otcHTTPClient = httputil.BoundedClient()
 
 // Fetches and flattens all product entries across multiple services with pagination
 func (otc *OTC) fetchPaginatedProducts(serviceNames []string) ([]Product, error) {