Explorar o código

Merge pull request #1486 from yeya24/useragent

Support setting user agent when querying metrics store
Ajay Tripathy %!s(int64=3) %!d(string=hai) anos
pai
achega
6094544521

+ 6 - 30
pkg/cloud/gcpprovider.go

@@ -76,16 +76,6 @@ var (
 	gceRegex = regexp.MustCompile("gce://([^/]*)/*")
 )
 
-type userAgentTransport struct {
-	userAgent string
-	base      http.RoundTripper
-}
-
-func (t userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
-	req.Header.Set("User-Agent", t.userAgent)
-	return t.base.RoundTrip(req)
-}
-
 // GCP implements a provider interface for GCP
 type GCP struct {
 	Pricing                 map[string]*GCPPricing
@@ -99,6 +89,7 @@ type GCP struct {
 	Config                  *ProviderConfig
 	ServiceKeyProvided      bool
 	ValidPricingKeys        map[string]bool
+	metadataClient          *metadata.Client
 	clusterManagementPrice  float64
 	clusterProjectId        string
 	clusterRegion           string
@@ -310,12 +301,7 @@ func (gcp *GCP) UpdateConfig(r io.Reader, updateType string) (*CustomPricing, er
 func (gcp *GCP) ClusterInfo() (map[string]string, error) {
 	remoteEnabled := env.IsRemoteEnabled()
 
-	metadataClient := metadata.NewClient(&http.Client{Transport: userAgentTransport{
-		userAgent: "kubecost",
-		base:      http.DefaultTransport,
-	}})
-
-	attribute, err := metadataClient.InstanceAttributeValue("cluster-name")
+	attribute, err := gcp.metadataClient.InstanceAttributeValue("cluster-name")
 	if err != nil {
 		log.Infof("Error loading metadata cluster-name: %s", err.Error())
 	}
@@ -348,13 +334,8 @@ func (gcp *GCP) ClusterManagementPricing() (string, float64, error) {
 	return gcp.clusterProvisioner, gcp.clusterManagementPrice, nil
 }
 
-func (*GCP) GetAddresses() ([]byte, error) {
-	// metadata API setup
-	metadataClient := metadata.NewClient(&http.Client{Transport: userAgentTransport{
-		userAgent: "kubecost",
-		base:      http.DefaultTransport,
-	}})
-	projID, err := metadataClient.ProjectID()
+func (gcp *GCP) GetAddresses() ([]byte, error) {
+	projID, err := gcp.metadataClient.ProjectID()
 	if err != nil {
 		return nil, err
 	}
@@ -377,13 +358,8 @@ func (*GCP) GetAddresses() ([]byte, error) {
 }
 
 // GetDisks returns the GCP disks backing PVs. Useful because sometimes k8s will not clean up PVs correctly. Requires a json config in /var/configs with key region.
-func (*GCP) GetDisks() ([]byte, error) {
-	// metadata API setup
-	metadataClient := metadata.NewClient(&http.Client{Transport: userAgentTransport{
-		userAgent: "kubecost",
-		base:      http.DefaultTransport,
-	}})
-	projID, err := metadataClient.ProjectID()
+func (gcp *GCP) GetDisks() ([]byte, error) {
+	projID, err := gcp.metadataClient.ProjectID()
 	if err != nil {
 		return nil, err
 	}

+ 5 - 0
pkg/cloud/provider.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net/http"
 	"regexp"
 	"strconv"
 	"strings"
@@ -21,6 +22,7 @@ import (
 	"github.com/opencost/opencost/pkg/config"
 	"github.com/opencost/opencost/pkg/env"
 	"github.com/opencost/opencost/pkg/log"
+	"github.com/opencost/opencost/pkg/util/httputil"
 	"github.com/opencost/opencost/pkg/util/watcher"
 
 	v1 "k8s.io/api/core/v1"
@@ -464,6 +466,9 @@ func NewProvider(cache clustercache.ClusterCache, apiKey string, config *config.
 			Config:           NewProviderConfig(config, cp.configFileName),
 			clusterRegion:    cp.region,
 			clusterProjectId: cp.projectID,
+			metadataClient: metadata.NewClient(&http.Client{
+				Transport: httputil.NewUserAgentTransport("kubecost", http.DefaultTransport),
+			}),
 		}, nil
 	case kubecost.AWSProvider:
 		log.Info("Found ProviderID starting with \"aws\", using AWS Provider")

+ 16 - 12
pkg/prom/prom.go

@@ -16,12 +16,15 @@ import (
 	"github.com/opencost/opencost/pkg/log"
 	"github.com/opencost/opencost/pkg/util/fileutil"
 	"github.com/opencost/opencost/pkg/util/httputil"
+	"github.com/opencost/opencost/pkg/version"
 
 	golog "log"
 
 	prometheus "github.com/prometheus/client_golang/api"
 )
 
+var UserAgent = fmt.Sprintf("Opencost/%s", version.Version)
+
 //--------------------------------------------------------------------------
 //  QueryParamsDecorator
 //--------------------------------------------------------------------------
@@ -355,19 +358,20 @@ type PrometheusClientConfig struct {
 // NewPrometheusClient creates a new rate limited client which limits by outbound concurrent requests.
 func NewPrometheusClient(address string, config *PrometheusClientConfig) (prometheus.Client, error) {
 	// may be necessary for long prometheus queries
-	pc := prometheus.Config{
-		Address: address,
-		RoundTripper: &http.Transport{
-			Proxy: http.ProxyFromEnvironment,
-			DialContext: (&net.Dialer{
-				Timeout:   config.Timeout,
-				KeepAlive: config.KeepAlive,
-			}).DialContext,
-			TLSHandshakeTimeout: config.TLSHandshakeTimeout,
-			TLSClientConfig: &tls.Config{
-				InsecureSkipVerify: config.TLSInsecureSkipVerify,
-			},
+	rt := httputil.NewUserAgentTransport(UserAgent, &http.Transport{
+		Proxy: http.ProxyFromEnvironment,
+		DialContext: (&net.Dialer{
+			Timeout:   config.Timeout,
+			KeepAlive: config.KeepAlive,
+		}).DialContext,
+		TLSHandshakeTimeout: config.TLSHandshakeTimeout,
+		TLSClientConfig: &tls.Config{
+			InsecureSkipVerify: config.TLSInsecureSkipVerify,
 		},
+	})
+	pc := prometheus.Config{
+		Address:      address,
+		RoundTripper: rt,
 	}
 
 	client, err := prometheus.NewClient(pc)

+ 1 - 1
pkg/prom/query.go

@@ -95,7 +95,7 @@ func (ctx *Context) Query(query string) QueryResultsChan {
 	return resCh
 }
 
-// QueryWithTime returns a QueryResultsChan, then runs the given query at the
+// QueryAtTime returns a QueryResultsChan, then runs the given query at the
 // given time (see time parameter here: https://prometheus.io/docs/prometheus/latest/querying/api/#instant-queries)
 // and sends the results on the provided channel. Receiver is responsible for
 // closing the channel, preferably using the Read method.

+ 31 - 0
pkg/util/httputil/roundtrip.go

@@ -0,0 +1,31 @@
+package httputil
+
+import "net/http"
+
+type userAgentTransport struct {
+	userAgent string
+	base      http.RoundTripper
+}
+
+// NewUserAgentTransport creates a RoundTripper that attaches the configured user agent.
+func NewUserAgentTransport(userAgent string, base http.RoundTripper) http.RoundTripper {
+	return &userAgentTransport{
+		userAgent: userAgent,
+		base:      base,
+	}
+}
+
+func (t userAgentTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+	// The specification of http.RoundTripper says that it shouldn't mutate
+	// the request so make a copy of req.Header since this is all that is
+	// modified.
+	r2 := new(http.Request)
+	*r2 = *r
+	r2.Header = make(http.Header)
+	for k, s := range r.Header {
+		r2.Header[k] = s
+	}
+	r2.Header.Set("User-Agent", t.userAgent)
+	r = r2
+	return t.base.RoundTrip(r)
+}

+ 57 - 0
pkg/util/httputil/roundtrip_test.go

@@ -0,0 +1,57 @@
+package httputil
+
+import (
+	"fmt"
+	"net/http"
+	"reflect"
+	"testing"
+)
+
+type reqValidateRoundTripper struct {
+	expectedReq *http.Request
+}
+
+func (rt *reqValidateRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+	if !reflect.DeepEqual(r, rt.expectedReq) {
+		return nil, fmt.Errorf("expected req %v, got %v", rt.expectedReq, r)
+	}
+	return nil, nil
+}
+
+func TestUserAgentTransport(t *testing.T) {
+	for _, tc := range []struct {
+		name   string
+		ua     string
+		req    *http.Request
+		expReq *http.Request
+	}{
+		{
+			name:   "opencost",
+			ua:     "opencost",
+			req:    &http.Request{},
+			expReq: &http.Request{Header: http.Header{"User-Agent": []string{"opencost"}}},
+		},
+		{
+			name:   "foo",
+			ua:     "foo",
+			req:    &http.Request{},
+			expReq: &http.Request{Header: http.Header{"User-Agent": []string{"foo"}}},
+		},
+		{
+			name:   "overwrite user agent if exists",
+			ua:     "opencost",
+			req:    &http.Request{Header: http.Header{"User-Agent": []string{"foo"}}},
+			expReq: &http.Request{Header: http.Header{"User-Agent": []string{"opencost"}}},
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			rt := NewUserAgentTransport(tc.ua, &reqValidateRoundTripper{
+				expectedReq: tc.expReq,
+			})
+			_, err := rt.RoundTrip(tc.req)
+			if err != nil {
+				t.Error(err)
+			}
+		})
+	}
+}