| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- package httputil
- import (
- "context"
- "fmt"
- "math"
- "net/http"
- "net/url"
- "strconv"
- "strings"
- "time"
- "github.com/opencost/opencost/pkg/util/mapper"
- )
- //--------------------------------------------------------------------------
- // QueryParams
- //--------------------------------------------------------------------------
- // valuesPrimitiveMap implements mapper.PrimitiveMap so we can build extra
- // functionality into the QueryParams interface.
- type valuesPrimitiveMap struct {
- url.Values
- }
- func (values valuesPrimitiveMap) Has(key string) bool {
- return values.Values.Has(key)
- }
- func (values valuesPrimitiveMap) Get(key string) string {
- return values.Values.Get(key)
- }
- func (values valuesPrimitiveMap) Set(key, value string) error {
- values.Values.Set(key, value)
- return nil
- }
- // QueryParams provides basic map access to URL values as well as providing
- // helpful additional functionality for validation.
- type QueryParams interface {
- mapper.PrimitiveMap
- // InvalidKeys returns the set of param keys which are not present in the
- // possible valid set. It is a set subtraction: present - valid = invalid
- //
- // Example usage to catch a typo:
- // qp.InvalidKeys([]string{"window", "aggregate", "filterClusters"}) ->
- // "filterClsuters"
- //
- // If qp contains no keys, then this should always return an empty slice/nil
- InvalidKeys(possibleValidKeys []string) (invalidKeys []string)
- }
- // queryParamsMap implements the QueryParams interface on top of
- // valuesPrimitiveMap.
- type queryParamsMap struct {
- values url.Values
- mapper.PrimitiveMap
- }
- // NewQueryParams creates a primitive map using the request query parameters
- func NewQueryParams(values url.Values) QueryParams {
- vpm := valuesPrimitiveMap{values}
- return &queryParamsMap{
- values: values,
- PrimitiveMap: mapper.NewMapper(vpm),
- }
- }
- // InvalidKeys performs a set difference: Params keys - possible valid keys.
- //
- // For now, dealing with cache busting parameters should be the handler's
- // responsibility.
- func (qpm *queryParamsMap) InvalidKeys(possibleValidKeys []string) []string {
- validMap := map[string]struct{}{}
- for _, validKey := range possibleValidKeys {
- validMap[validKey] = struct{}{}
- }
- var invalidKeys []string
- for key := range qpm.values {
- if _, ok := validMap[key]; !ok {
- invalidKeys = append(invalidKeys, key)
- }
- }
- return invalidKeys
- }
- //--------------------------------------------------------------------------
- // HTTP Context Utilities
- //--------------------------------------------------------------------------
- const (
- ContextWarning string = "Warning"
- ContextName string = "Name"
- ContextQuery string = "Query"
- )
- // GetWarning Extracts a warning message from the request context if it exists
- func GetWarning(r *http.Request) (warning string, ok bool) {
- warning, ok = r.Context().Value(ContextWarning).(string)
- return
- }
- // SetWarning Sets the warning context on the provided request and returns a new instance of the request
- // with the new context.
- func SetWarning(r *http.Request, warning string) *http.Request {
- ctx := context.WithValue(r.Context(), ContextWarning, warning)
- return r.WithContext(ctx)
- }
- // GetName Extracts a name value from the request context if it exists
- func GetName(r *http.Request) (name string, ok bool) {
- name, ok = r.Context().Value(ContextName).(string)
- return
- }
- // SetName Sets the name value on the provided request and returns a new instance of the request
- // with the new context.
- func SetName(r *http.Request, name string) *http.Request {
- ctx := context.WithValue(r.Context(), ContextName, name)
- return r.WithContext(ctx)
- }
- // GetQuery Extracts a query value from the request context if it exists
- func GetQuery(r *http.Request) (name string, ok bool) {
- name, ok = r.Context().Value(ContextQuery).(string)
- return
- }
- // SetQuery Sets the query value on the provided request and returns a new instance of the request
- // with the new context.
- func SetQuery(r *http.Request, query string) *http.Request {
- ctx := context.WithValue(r.Context(), ContextQuery, query)
- return r.WithContext(ctx)
- }
- //--------------------------------------------------------------------------
- // Package Funcs
- //--------------------------------------------------------------------------
- // IsRateLimited accepts a response and body to determine if either indicate
- // a rate limited return
- func IsRateLimited(resp *http.Response, body []byte) bool {
- return IsRateLimitedResponse(resp) || IsRateLimitedBody(resp, body)
- }
- // RateLimitedRetryFor returns the parsed Retry-After header relative to the
- // current time. If the Retry-After header does not exist, the defaultWait parameter
- // is returned.
- func RateLimitedRetryFor(resp *http.Response, defaultWait time.Duration, retry int) time.Duration {
- if resp.Header == nil {
- return ExponentialBackoffWaitFor(defaultWait, retry)
- }
- // Retry-After is either the number of seconds to wait or a target datetime (RFC1123)
- value := resp.Header.Get("Retry-After")
- if value == "" {
- return defaultWait
- }
- seconds, err := strconv.ParseInt(value, 10, 64)
- if err == nil {
- return time.Duration(seconds) * time.Second
- }
- // failed to parse an integer, try datetime RFC1123
- t, err := time.Parse(time.RFC1123, value)
- if err == nil {
- // return 0 if the datetime has already elapsed
- result := t.Sub(time.Now())
- if result < 0 {
- return 0
- }
- return result
- }
- // failed to parse datetime, return default
- return defaultWait
- }
- // ExpontentialBackoffWatiFor accepts a default wait duration and the current retry count
- // and returns a new duration
- func ExponentialBackoffWaitFor(defaultWait time.Duration, retry int) time.Duration {
- return time.Duration(math.Pow(2, float64(retry))*float64(defaultWait.Milliseconds())) * time.Millisecond
- }
- // IsRateLimitedResponse returns true if the status code is a 429 (TooManyRequests)
- func IsRateLimitedResponse(resp *http.Response) bool {
- return resp.StatusCode == http.StatusTooManyRequests
- }
- // IsRateLimitedBody attempts to determine if a response body indicates throttling
- // has occurred. This function is a result of some API providers (AWS) returning
- // a 400 status code instead of 429 for rate limit exceptions.
- func IsRateLimitedBody(resp *http.Response, body []byte) bool {
- // ignore non-400 status
- if resp.StatusCode < http.StatusBadRequest || resp.StatusCode >= http.StatusInternalServerError {
- return false
- }
- return strings.Contains(string(body), "ThrottlingException")
- }
- // HeaderString writes the request/response http.Header to a string.
- func HeaderString(h http.Header) string {
- var sb strings.Builder
- var first bool = true
- sb.WriteString("{ ")
- for k, vs := range h {
- if first {
- first = false
- } else {
- sb.WriteString(", ")
- }
- fmt.Fprintf(&sb, "%s: [ ", k)
- for idx, v := range vs {
- sb.WriteString(v)
- if idx != len(vs)-1 {
- sb.WriteString(", ")
- }
- }
- sb.WriteString(" ]")
- }
- sb.WriteString(" }")
- return sb.String()
- }
|