Merge pull request #1214 from kubecost/bolt/monitoring

Cost-Model Monitoring and Diagnostic Enhancements
Matt Bolt 4 years ago
parent
commit
68becaa698

+ 1 - 0
go.mod

@@ -89,6 +89,7 @@ require (
 	github.com/jstemmer/go-junit-report v0.9.1 // indirect
 	github.com/klauspost/compress v1.13.5 // indirect
 	github.com/klauspost/cpuid v1.3.1 // indirect
+	github.com/kubecost/events v0.0.3 // indirect
 	github.com/magiconair/properties v1.8.5 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
 	github.com/minio/md5-simd v1.1.0 // indirect

+ 2 - 0
go.sum

@@ -395,6 +395,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kubecost/events v0.0.3 h1:q0Hn8DsovzW53T2oSRfZF92JDwlmAAQt0aktl4ccm74=
+github.com/kubecost/events v0.0.3/go.mod h1:i3DyCVatehxq6tAbvBrARuafjkX2DECPk9OWxiaRIhY=
 github.com/labstack/echo/v4 v4.1.11/go.mod h1:i541M3Fj6f76NZtHSj7TXnyM8n2gaodfvfxNnFqi74g=
 github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
 github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=

+ 3 - 1
pkg/cmd/agent/agent.go

@@ -14,6 +14,7 @@ import (
 	"github.com/kubecost/cost-model/pkg/env"
 	"github.com/kubecost/cost-model/pkg/kubeconfig"
 	"github.com/kubecost/cost-model/pkg/log"
+	"github.com/kubecost/cost-model/pkg/metrics"
 	"github.com/kubecost/cost-model/pkg/prom"
 	"github.com/kubecost/cost-model/pkg/util/watcher"
 
@@ -216,7 +217,8 @@ func Execute(opts *AgentOpts) error {
 	rootMux := http.NewServeMux()
 	rootMux.HandleFunc("/healthz", Healthz)
 	rootMux.Handle("/metrics", promhttp.Handler())
-	handler := cors.AllowAll().Handler(rootMux)
+	telemetryHandler := metrics.ResponseMetricMiddleware(rootMux)
+	handler := cors.AllowAll().Handler(telemetryHandler)
 
 	return http.ListenAndServe(fmt.Sprintf(":%d", env.GetKubecostMetricsPort()), handler)
 }
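
Note: both the agent (here) and the cost-model entry point (below) now wrap their mux in the new telemetry middleware before applying CORS, so every request, including /metrics and /healthz, is measured. A minimal sketch of the resulting handler chain, assuming a trivial mux (the route and port are illustrative, not taken from this PR):

package main

import (
	"net/http"

	"github.com/kubecost/cost-model/pkg/metrics"
	"github.com/rs/cors"
)

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK) // hypothetical handler, for illustration only
	})

	// innermost: the mux; then telemetry; CORS outermost
	telemetry := metrics.ResponseMetricMiddleware(mux)
	handler := cors.AllowAll().Handler(telemetry)

	http.ListenAndServe(":9003", handler)
}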

+ 3 - 1
pkg/cmd/costmodel/costmodel.go

@@ -6,6 +6,7 @@ import (
 	"github.com/julienschmidt/httprouter"
 	"github.com/kubecost/cost-model/pkg/costmodel"
 	"github.com/kubecost/cost-model/pkg/errors"
+	"github.com/kubecost/cost-model/pkg/metrics"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rs/cors"
 )
@@ -29,7 +30,8 @@ func Execute(opts *CostModelOpts) error {
 	a.Router.GET("/allocation/summary", a.ComputeAllocationHandlerSummary)
 	rootMux.Handle("/", a.Router)
 	rootMux.Handle("/metrics", promhttp.Handler())
-	handler := cors.AllowAll().Handler(rootMux)
+	telemetryHandler := metrics.ResponseMetricMiddleware(rootMux)
+	handler := cors.AllowAll().Handler(telemetryHandler)
 
 	return http.ListenAndServe(":9003", errors.PanicHandlerMiddleware(handler))
 }

+ 2 - 0
pkg/costmodel/metrics.go

@@ -346,6 +346,8 @@ func NewCostModelMetricsEmitter(promClient promclient.Client, clusterCache clust
 		EmitKubeStateMetricsV1Only:    env.IsEmitKsmV1MetricsOnly(),
 	})
 
+	metrics.InitKubecostTelemetry(metricsConfig)
+
 	return &CostModelMetricsEmitter{
 		PrometheusClient:              promClient,
 		KubeClusterCache:              clusterCache,
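
Note: NewCostModelMetricsEmitter now initializes the HTTP telemetry collectors as a side effect. The sync.Once guard inside InitKubecostTelemetry (see pkg/metrics/telemetry.go below) matters here, because prometheus.MustRegister panics on duplicate registration. A minimal sketch of the failure mode the guard prevents (metric name illustrative):

package main

import "github.com/prometheus/client_golang/prometheus"

func main() {
	a := prometheus.NewCounter(prometheus.CounterOpts{Name: "demo_total", Help: "demo"})
	prometheus.MustRegister(a)

	// registering a second collector for the same fully-qualified metric name
	// panics with a "duplicate metrics collector registration attempted" error
	b := prometheus.NewCounter(prometheus.CounterOpts{Name: "demo_total", Help: "demo"})
	prometheus.MustRegister(b) // panics
}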

+ 3 - 11
pkg/costmodel/router.go

@@ -882,23 +882,15 @@ func (a *Accesses) GetPrometheusMetrics(w http.ResponseWriter, _ *http.Request,
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Access-Control-Allow-Origin", "*")
 
-	promMetrics, err := prom.GetPrometheusMetrics(a.PrometheusClient, "")
-	if err != nil {
-		w.Write(WrapData(nil, err))
-		return
-	}
+	promMetrics := prom.GetPrometheusMetrics(a.PrometheusClient, "")
 
 	result := map[string][]*prom.PrometheusDiagnostic{
 		"prometheus": promMetrics,
 	}
 
 	if thanos.IsEnabled() {
-		thanosMetrics, err := prom.GetPrometheusMetrics(a.ThanosClient, thanos.QueryOffset())
-		if err != nil {
-			log.Warnf("Error getting Thanos queue state: %s", err)
-		} else {
-			result["thanos"] = thanosMetrics
-		}
+		thanosMetrics := prom.GetPrometheusMetrics(a.ThanosClient, thanos.QueryOffset())
+		result["thanos"] = thanosMetrics
 	}
 
 	w.Write(WrapData(result, nil))

+ 0 - 7
pkg/kubecost/common.go

@@ -14,10 +14,3 @@ func NewPair[T any, U any](first T, second U) Pair[T, U] {
 		Second: second,
 	}
 }
-
-// DefaultValue[T] returns the default value for any generic type. This is helpful for generic
-// types where a type parameter can be a value type or pointer.
-func DefaultValue[T any]() T {
-	var t T
-	return t
-}

+ 12 - 0
pkg/metrics/events.go

@@ -0,0 +1,12 @@
+package metrics
+
+import "time"
+
+// HttpHandlerMetricEvent contains http handler response metrics.
+type HttpHandlerMetricEvent struct {
+	Handler      string
+	Code         int
+	Method       string
+	ResponseTime time.Duration
+	ResponseSize uint64
+}
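
Note: HttpHandlerMetricEvent flows through the kubecost/events dependency added in go.mod above. Based only on the calls this PR itself makes (GlobalDispatcherFor, Dispatch, AddEventHandler), a minimal end-to-end sketch; whether Dispatch delivers synchronously or asynchronously is not specified here, so treat delivery timing as an assumption:

package main

import (
	"fmt"
	"time"

	"github.com/kubecost/cost-model/pkg/metrics"
	"github.com/kubecost/events"
)

func main() {
	dispatcher := events.GlobalDispatcherFor[metrics.HttpHandlerMetricEvent]()

	// a consumer, as telemetry.go registers below
	dispatcher.AddEventHandler(func(e metrics.HttpHandlerMetricEvent) {
		fmt.Printf("%s %s -> %d in %s (%d bytes)\n",
			e.Method, e.Handler, e.Code, e.ResponseTime, e.ResponseSize)
	})

	// a producer, as the middleware dispatches below (values illustrative)
	dispatcher.Dispatch(metrics.HttpHandlerMetricEvent{
		Handler:      "/allocation",
		Method:       "GET",
		Code:         200,
		ResponseTime: 42 * time.Millisecond,
		ResponseSize: 1024,
	})
}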

+ 80 - 0
pkg/metrics/httpmetricmiddleware.go

@@ -0,0 +1,80 @@
+package metrics
+
+import (
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/kubecost/events"
+)
+
+// ResponseMetricMiddleware dispatches metric events for handled requests and responses.
+func ResponseMetricMiddleware(handler http.Handler) http.Handler {
+	dispatcher := events.GlobalDispatcherFor[HttpHandlerMetricEvent]()
+
+	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+		// use a ResponseWriter implementation to record telemetry for the response
+		respWriter := &responseWriterAdapter{w: rw}
+
+		// record method and path of the request
+		method := r.Method
+		path := r.URL.Path
+
+		// time and execute the handler
+		start := time.Now()
+		handler.ServeHTTP(respWriter, r)
+		duration := time.Since(start)
+
+		// record the response code and size
+		code := respWriter.StatusCode()
+		size := respWriter.TotalResponseSize()
+
+		dispatcher.Dispatch(HttpHandlerMetricEvent{
+			Handler:      path,
+			Method:       method,
+			Code:         code,
+			ResponseTime: duration,
+			ResponseSize: size,
+		})
+	})
+}
+
+// responseWriterAdapter implements http.ResponseWriter, recording the status code and total response size.
+type responseWriterAdapter struct {
+	w          http.ResponseWriter
+	written    bool
+	statusCode int
+	size       uint64
+}
+
+func (wd *responseWriterAdapter) Header() http.Header {
+	return wd.w.Header()
+}
+
+func (wd *responseWriterAdapter) Write(bytes []byte) (int, error) {
+	numBytes, err := wd.w.Write(bytes)
+	wd.size += uint64(numBytes)
+	return numBytes, err
+}
+
+func (wd *responseWriterAdapter) WriteHeader(statusCode int) {
+	wd.written = true
+	wd.statusCode = statusCode
+	wd.w.WriteHeader(statusCode)
+}
+
+func (wd *responseWriterAdapter) StatusCode() int {
+	if !wd.written {
+		return http.StatusOK
+	}
+	return wd.statusCode
+}
+
+func (wd *responseWriterAdapter) Status() string {
+	return fmt.Sprintf("%d", wd.StatusCode())
+}
+
+func (wd *responseWriterAdapter) TotalResponseSize() uint64 {
+	return wd.size
+}
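
Note: a hypothetical test (not in this PR) exercising the middleware end to end with net/http/httptest; it assumes dispatching on a dispatcher with no registered listeners is a no-op:

package metrics

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestResponseMetricMiddlewareSketch(t *testing.T) {
	inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		// no explicit WriteHeader: the adapter should report 200 via StatusCode()
		w.Write([]byte("ok"))
	})

	wrapped := ResponseMetricMiddleware(inner)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	if body := rec.Body.String(); body != "ok" {
		t.Fatalf("unexpected body: %q", body)
	}
}

Two review observations, offered tentatively: responseWriterAdapter does not forward optional interfaces such as http.Flusher or http.Hijacker, so wrapped handlers lose streaming and connection-hijacking support; and using the raw r.URL.Path as the handler label can inflate metric cardinality on parameterized routes.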

+ 61 - 0
pkg/metrics/telemetry.go

@@ -0,0 +1,61 @@
+package metrics
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/kubecost/events"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+var (
+	once       sync.Once
+	dispatcher events.Dispatcher[HttpHandlerMetricEvent]
+	// -- append new dispatchers here for new event types
+
+	// prometheus metrics
+	requestsCount *prometheus.CounterVec
+	responseTime  *prometheus.HistogramVec
+	responseSize  *prometheus.SummaryVec
+)
+
+// InitKubecostTelemetry registers kubecost application telemetry.
+func InitKubecostTelemetry(config *MetricsConfig) {
+	// TODO(bolt): Check MetricsConfig for disabled metrics
+
+	once.Do(func() {
+		// register prometheus metrics
+		requestsCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+			Name: "kubecost_http_requests_total",
+			Help: "kubecost_http_requests_total Total number of HTTP requests",
+		}, []string{"handler", "method", "code"})
+
+		var buckets = []float64{0.001, 0.01, 0.1, 0.3, 0.6, 1, 3, 6, 9, 20, 30, 60, 90, 120, 240, 360, 720}
+		responseTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+			Name:    "kubecost_http_response_time_seconds",
+			Help:    "kubecost_http_response_time_seconds Response time in seconds",
+			Buckets: buckets,
+		}, []string{"handler", "method", "code"})
+
+		responseSize = prometheus.NewSummaryVec(prometheus.SummaryOpts{
+			Name: "kubecost_http_response_size_bytes",
+			Help: "kubecost_http_response_size_bytes Response size in bytes",
+		}, []string{"handler", "method", "code"})
+
+		prometheus.MustRegister(requestsCount, responseTime, responseSize)
+
+		// register event listeners
+		dispatcher = events.GlobalDispatcherFor[HttpHandlerMetricEvent]()
+		dispatcher.AddEventHandler(onHttpHandlerMetricEvent)
+		// -- append new event handlers here
+	})
+}
+
+// onHttpHandlerMetricEvent handles all incoming HttpHandlerMetricEvents
+func onHttpHandlerMetricEvent(event HttpHandlerMetricEvent) {
+	code := fmt.Sprintf("%d", event.Code)
+
+	requestsCount.WithLabelValues(event.Handler, event.Method, code).Inc()
+	responseSize.WithLabelValues(event.Handler, event.Method, code).Observe(float64(event.ResponseSize))
+	responseTime.WithLabelValues(event.Handler, event.Method, code).Observe(event.ResponseTime.Seconds())
+}
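
Note: a hypothetical test (not in this PR) for the event-to-metric fan-out, checked with prometheus/client_golang's testutil package; it assumes a zero-value MetricsConfig is acceptable while the TODO above stands:

package metrics

import (
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus/testutil"
)

func TestOnHttpHandlerMetricEventSketch(t *testing.T) {
	InitKubecostTelemetry(&MetricsConfig{}) // assumption: zero-value config is tolerated

	onHttpHandlerMetricEvent(HttpHandlerMetricEvent{
		Handler:      "/model/allocation", // illustrative path
		Method:       "GET",
		Code:         200,
		ResponseTime: 125 * time.Millisecond,
		ResponseSize: 2048,
	})

	got := testutil.ToFloat64(requestsCount.WithLabelValues("/model/allocation", "GET", "200"))
	if got != 1 {
		t.Fatalf("expected a request count of 1, got %v", got)
	}
}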

+ 164 - 61
pkg/prom/diagnostics.go

@@ -9,6 +9,95 @@ import (
 	prometheus "github.com/prometheus/client_golang/api"
 )
 
+// Prometheus Metric Diagnostic IDs
+const (
+	// CAdvisorDiagnosticMetricID is the identifier of the metric used to determine if cAdvisor is being scraped.
+	CAdvisorDiagnosticMetricID = "cadvisorMetric"
+
+	// CAdvisorLabelDiagnosticMetricID is the identifier of the metric used to determine if cAdvisor labels are correct.
+	CAdvisorLabelDiagnosticMetricID = "cadvisorLabel"
+
+	// KSMDiagnosticMetricID is the identifier for the metric used to determine if KSM metrics are being scraped.
+	KSMDiagnosticMetricID = "ksmMetric"
+
+	// KSMVersionDiagnosticMetricID is the identifier for the metric used to determine if KSM version is correct.
+	KSMVersionDiagnosticMetricID = "ksmVersion"
+
+	// KubecostDiagnosticMetricID is the identifier for the metric used to determine if Kubecost metrics are being scraped.
+	KubecostDiagnosticMetricID = "kubecostMetric"
+
+	// NodeExporterDiagnosticMetricID is the identifier for the metric used to determine if NodeExporter metrics are being scraped.
+	NodeExporterDiagnosticMetricID = "neMetric"
+
+	// ScrapeIntervalDiagnosticMetricID is the identifier for the metric used to determine if prometheus has its own self-scraped
+	// metrics.
+	ScrapeIntervalDiagnosticMetricID = "scrapeInterval"
+
+	// CPUThrottlingDiagnosticMetricID is the identifier for the metric used to determine if CPU throttling is being applied to the
+	// cost-model container.
+	CPUThrottlingDiagnosticMetricID = "cpuThrottling"
+)
+
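+// DocumentationBaseURL is the base URL for the kubecost diagnostics documentation.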
+const DocumentationBaseURL = "https://github.com/kubecost/docs/blob/master/diagnostics.md"
+
+// diagnosticDefinitions holds all of the diagnostic definitions, keyed by ID, used for prometheus metric diagnostics
+var diagnosticDefinitions = map[string]*diagnosticDefinition{
+	CAdvisorDiagnosticMetricID: {
+		ID:          CAdvisorDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(container_cpu_usage_seconds_total[5m] %s)`,
+		Label:       "cAdvsior metrics available",
+		Description: "Determine if cAdvisor metrics are available during last 5 minutes.",
+		DocLink:     fmt.Sprintf("%s#cadvisor-metrics-available", DocumentationBaseURL),
+	},
+	KSMDiagnosticMetricID: {
+		ID:          KSMDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(kube_pod_container_resource_requests{resource="memory", unit="byte"}[5m] %s)`,
+		Label:       "Kube-state-metrics available",
+		Description: "Determine if metrics from kube-state-metrics are available during last 5 minutes.",
+		DocLink:     fmt.Sprintf("%s#kube-state-metrics-metrics-available", DocumentationBaseURL),
+	},
+	KubecostDiagnosticMetricID: {
+		ID:          KubecostDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(node_cpu_hourly_cost[5m] %s)`,
+		Label:       "Kubecost metrics available",
+		Description: "Determine if metrics from Kubecost are available during last 5 minutes.",
+	},
+	NodeExporterDiagnosticMetricID: {
+		ID:          NodeExporterDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(node_cpu_seconds_total[5m] %s)`,
+		Label:       "Node-exporter metrics available",
+		Description: "Determine if metrics from node-exporter are available during last 5 minutes.",
+		DocLink:     fmt.Sprintf("%s#node-exporter-metrics-available", DocumentationBaseURL),
+	},
+	CAdvisorLabelDiagnosticMetricID: {
+		ID:          CAdvisorLabelDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(container_cpu_usage_seconds_total{container!="",pod!=""}[5m] %s)`,
+		Label:       "Expected cAdvsior labels available",
+		Description: "Determine if expected cAdvisor labels are present during last 5 minutes.",
+		DocLink:     fmt.Sprintf("%s#cadvisor-metrics-available", DocumentationBaseURL),
+	},
+	KSMVersionDiagnosticMetricID: {
+		ID:          KSMVersionDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(kube_persistentvolume_capacity_bytes[5m] %s)`,
+		Label:       "Expected kube-state-metrics version found",
+		Description: "Determine if metric in required kube-state-metrics version are present during last 5 minutes.",
+		DocLink:     fmt.Sprintf("%s#expected-kube-state-metrics-version-found", DocumentationBaseURL),
+	},
+	ScrapeIntervalDiagnosticMetricID: {
+		ID:          ScrapeIntervalDiagnosticMetricID,
+		QueryFmt:    `absent_over_time(prometheus_target_interval_length_seconds[5m] %s)`,
+		Label:       "Expected Prometheus self-scrape metrics available",
+		Description: "Determine if prometheus has its own self-scraped metrics during the last 5 minutes.",
+	},
+	CPUThrottlingDiagnosticMetricID: {
+		ID: CPUThrottlingDiagnosticMetricID,
+		QueryFmt: `avg(increase(container_cpu_cfs_throttled_periods_total{container="cost-model"}[10m] %s)) by (container_name, pod_name, namespace)
+	/ avg(increase(container_cpu_cfs_periods_total{container="cost-model"}[10m] %s)) by (container_name, pod_name, namespace) > 0.2`,
+		Label:       "Kubecost is not CPU throttled",
+		Description: "Kubecost loading slowly? A kubecost component might be CPU throttled",
+	},
+}
+
 // QueuedPromRequest is a representation of a request waiting to be sent by the prometheus
 // client.
 type QueuedPromRequest struct {
@@ -66,75 +155,89 @@ func LogPrometheusClientState(client prometheus.Client) {
 }
 
 // GetPrometheusMetrics returns a list of the states of the Prometheus metrics used by kubecost, using the provided client
-func GetPrometheusMetrics(client prometheus.Client, offset string) ([]*PrometheusDiagnostic, error) {
-	docs := "https://github.com/kubecost/docs/blob/master/diagnostics.md"
+func GetPrometheusMetrics(client prometheus.Client, offset string) PrometheusDiagnostics {
 	ctx := NewNamedContext(client, DiagnosticContextName)
 
-	result := []*PrometheusDiagnostic{
-		{
-			ID:          "cadvisorMetric",
-			Query:       fmt.Sprintf(`absent_over_time(container_cpu_usage_seconds_total[5m] %s)`, offset),
-			Label:       "cAdvsior metrics available",
-			Description: "Determine if cAdvisor metrics are available during last 5 minutes.",
-			DocLink:     fmt.Sprintf("%s#cadvisor-metrics-available", docs),
-		},
-		{
-			ID:          "ksmMetric",
-			Query:       fmt.Sprintf(`absent_over_time(kube_pod_container_resource_requests{resource="memory", unit="byte"}[5m]  %s)`, offset),
-			Label:       "Kube-state-metrics available",
-			Description: "Determine if metrics from kube-state-metrics are available during last 5 minutes.",
-			DocLink:     fmt.Sprintf("%s#kube-state-metrics-metrics-available", docs),
-		},
-		{
-			ID:          "kubecostMetric",
-			Query:       fmt.Sprintf(`absent_over_time(node_cpu_hourly_cost[5m]  %s)`, offset),
-			Label:       "Kubecost metrics available",
-			Description: "Determine if metrics from Kubecost are available during last 5 minutes.",
-		},
-		{
-			ID:          "neMetric",
-			Query:       fmt.Sprintf(`absent_over_time(node_cpu_seconds_total[5m]  %s)`, offset),
-			Label:       "Node-exporter metrics available",
-			Description: "Determine if metrics from node-exporter are available during last 5 minutes.",
-			DocLink:     fmt.Sprintf("%s#node-exporter-metrics-available", docs),
-		},
-		{
-			ID:          "cadvisorLabel",
-			Query:       fmt.Sprintf(`absent_over_time(container_cpu_usage_seconds_total{container!="",pod!=""}[5m]  %s)`, offset),
-			Label:       "Expected cAdvsior labels available",
-			Description: "Determine if expected cAdvisor labels are present during last 5 minutes.",
-			DocLink:     fmt.Sprintf("%s#cadvisor-metrics-available", docs),
-		},
-		{
-			ID:          "ksmVersion",
-			Query:       fmt.Sprintf(`absent_over_time(kube_persistentvolume_capacity_bytes[5m]  %s)`, offset),
-			Label:       "Expected kube-state-metrics version found",
-			Description: "Determine if metric in required kube-state-metrics version are present during last 5 minutes.",
-			DocLink:     fmt.Sprintf("%s#expected-kube-state-metrics-version-found", docs),
-		},
-		{
-			ID:          "scrapeInterval",
-			Query:       fmt.Sprintf(`absent_over_time(prometheus_target_interval_length_seconds[5m]  %s)`, offset),
-			Label:       "Expected Prometheus self-scrape metrics available",
-			Description: "Determine if prometheus has its own self-scraped metrics during the last 5 minutes.",
-		},
-		{
-			ID: "cpuThrottling",
-			Query: `avg(increase(container_cpu_cfs_throttled_periods_total{container="cost-model"}[10m])) by (container_name, pod_name, namespace)
-		/ avg(increase(container_cpu_cfs_periods_total{container="cost-model"}[10m])) by (container_name, pod_name, namespace) > 0.2`,
-			Label:       "Kubecost is not CPU throttled",
-			Description: "Kubecost loading slowly? A kubecost component might be CPU throttled",
-		},
-	}
-
-	for _, pd := range result {
+	var result []*PrometheusDiagnostic
+	for _, definition := range diagnosticDefinitions {
+		pd := definition.NewDiagnostic(offset)
 		err := pd.executePrometheusDiagnosticQuery(ctx)
+
+		// log the error, append to the result anyway, and continue
 		if err != nil {
 			log.Errorf(err.Error())
 		}
+		result = append(result, pd)
+	}
+
+	return result
+}
+
+// GetPrometheusMetricsByID returns a list of the state of specific Prometheus metrics by identifier.
+func GetPrometheusMetricsByID(ids []string, client prometheus.Client, offset string) PrometheusDiagnostics {
+	ctx := NewNamedContext(client, DiagnosticContextName)
+
+	var result []*PrometheusDiagnostic
+	for _, id := range ids {
+		if definition, ok := diagnosticDefinitions[id]; ok {
+			pd := definition.NewDiagnostic(offset)
+			err := pd.executePrometheusDiagnosticQuery(ctx)
+
+			// log the error, append to the result anyway, and continue
+			if err != nil {
+				log.Errorf(err.Error())
+			}
+			result = append(result, pd)
+		} else {
+			log.Warnf("Failed to find diagnostic definition for id: %s", id)
+		}
+	}
+
+	return result
+}
+
+// PrometheusDiagnostics is a PrometheusDiagnostic container with helper methods.
+type PrometheusDiagnostics []*PrometheusDiagnostic
+
+// HasFailure returns true if any of the diagnostic tests didn't pass.
+func (pd PrometheusDiagnostics) HasFailure() bool {
+	for _, p := range pd {
+		if !p.Passed {
+			return true
+		}
 	}
 
-	return result, nil
+	return false
+}
+
+// diagnosticDefinition is a definition of a diagnostic that can be used to create new
+// PrometheusDiagnostic instances using the definition's fields.
+type diagnosticDefinition struct {
+	ID          string
+	QueryFmt    string
+	Label       string
+	Description string
+	DocLink     string
+}
+
+// NewDiagnostic creates a new PrometheusDiagnostic from the definition, substituting the provided offset into the query.
+func (pdd *diagnosticDefinition) NewDiagnostic(offset string) *PrometheusDiagnostic {
+	// FIXME: Any reasonable way to get the total number of replacements required in the query?
+	// FIXME: All of the other queries require a single offset replace, but CPUThrottle requires two.
+	var query string
+	if pdd.ID == CPUThrottlingDiagnosticMetricID {
+		query = fmt.Sprintf(pdd.QueryFmt, offset, offset)
+	} else {
+		query = fmt.Sprintf(pdd.QueryFmt, offset)
+	}
+
+	return &PrometheusDiagnostic{
+		ID:          pdd.ID,
+		Query:       query,
+		Label:       pdd.Label,
+		Description: pdd.Description,
+		DocLink:     pdd.DocLink,
+	}
 }
 
 // PrometheusDiagnostic holds information about a metric and the query to ensure it is functional
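
Note: one way to resolve the FIXMEs in NewDiagnostic is to derive the substitution count from the format string itself, removing the per-ID special case. A minimal sketch that would live alongside NewDiagnostic in pkg/prom (it additionally needs the strings import), assuming every %s verb in QueryFmt takes the offset:

// query formats QueryFmt with as many copies of offset as it has %s verbs
func (pdd *diagnosticDefinition) query(offset string) string {
	n := strings.Count(pdd.QueryFmt, "%s")
	args := make([]interface{}, n)
	for i := range args {
		args[i] = offset
	}
	return fmt.Sprintf(pdd.QueryFmt, args...)
}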

+ 8 - 0
pkg/util/defaults/defaults.go

@@ -0,0 +1,8 @@
+package defaults
+
+// Default[T] returns the zero value for any generic type T. This is helpful for generic
+// code where a type parameter can be either a value type or a pointer.
+func Default[T any]() T {
+	var t T
+	return t
+}
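
Note: Default replaces the DefaultValue helper removed from pkg/kubecost/common.go above; it returns Go's zero value for T, which is what lets the generic Retry below bail out cleanly for both value and pointer type parameters. Illustrative usage:

package main

import (
	"fmt"
	"net/http"

	"github.com/kubecost/cost-model/pkg/util/defaults"
)

func main() {
	fmt.Println(defaults.Default[int]())          // 0
	fmt.Println(defaults.Default[string]() == "") // true
	fmt.Println(defaults.Default[*http.Client]()) // <nil>
}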

+ 5 - 3
pkg/util/retry/retry.go

@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"math/rand"
 	"time"
+
+	"github.com/kubecost/cost-model/pkg/util/defaults"
 )
 
 // RetryCancellationErr is the error type that's returned if the retry is cancelled
@@ -16,15 +18,15 @@ func IsRetryCancelledError(err error) bool {
 }
 
 // Retry runs f until it returns a non-error result, up to the provided number of attempts, or until the context is cancelled.
-func Retry(ctx context.Context, f func() (interface{}, error), attempts uint, delay time.Duration) (interface{}, error) {
-	var result interface{}
+func Retry[T any](ctx context.Context, f func() (T, error), attempts uint, delay time.Duration) (T, error) {
+	var result T
 	var err error
 
 	d := delay
 	for r := attempts; r > 0; r-- {
 		select {
 		case <-ctx.Done():
-			return nil, RetryCancellationErr
+			return defaults.Default[T](), RetryCancellationErr
 		default:
 		}
 

+ 7 - 8
pkg/util/retry/retry_test.go

@@ -18,7 +18,7 @@ func TestPtrSliceRetry(t *testing.T) {
 
 	var count uint64 = 0
 
-	f := func() (interface{}, error) {
+	f := func() ([]*Obj, error) {
 		c := atomic.AddUint64(&count, 1)
 		fmt.Println("Try:", c)
 
@@ -33,9 +33,8 @@ func TestPtrSliceRetry(t *testing.T) {
 		return nil, fmt.Errorf("Failed: %d", c)
 	}
 
-	result, err := Retry(context.Background(), f, 5, time.Second)
-	objs, ok := result.([]*Obj)
-	if err != nil || !ok {
+	objs, err := Retry(context.Background(), f, 5, time.Second)
+	if err != nil {
 		t.Fatalf("Failed to correctly cast back to slice type")
 	}
 
@@ -48,12 +47,12 @@ func TestSuccessRetry(t *testing.T) {
 
 	var count uint64 = 0
 
-	f := func() (interface{}, error) {
+	f := func() (any, error) {
 		c := atomic.AddUint64(&count, 1)
 		fmt.Println("Try:", c)
 
 		if c == Expected {
-			return struct{}{}, nil
+			return nil, nil
 		}
 
 		return nil, fmt.Errorf("Failed: %d", c)
@@ -72,7 +71,7 @@ func TestFailRetry(t *testing.T) {
 	expectedError := fmt.Sprintf("Failed: %d", Expected)
 	var count uint64 = 0
 
-	f := func() (interface{}, error) {
+	f := func() (any, error) {
 		c := atomic.AddUint64(&count, 1)
 		fmt.Println("Try:", c)
 		return nil, fmt.Errorf("Failed: %d", c)
@@ -95,7 +94,7 @@ func TestCancelRetry(t *testing.T) {
 
 	var count uint64 = 0
 
-	f := func() (interface{}, error) {
+	f := func() (any, error) {
 		c := atomic.AddUint64(&count, 1)
 		fmt.Println("Try:", c)
 		return nil, fmt.Errorf("Failed: %d", c)
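
Note: with the generic signature, callers of Retry get a typed result back with no interface{} assertion, as the updated TestPtrSliceRetry above shows. A minimal caller sketch (the URL and target are hypothetical):

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"github.com/kubecost/cost-model/pkg/util/retry"
)

func main() {
	resp, err := retry.Retry(context.Background(), func() (*http.Response, error) {
		return http.Get("http://localhost:9003/healthz") // hypothetical endpoint
	}, 5, time.Second)
	if err != nil {
		fmt.Println("giving up:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}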