httputil.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package httputil
import (
	"context"
	"fmt"
	"math"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/opencost/opencost/pkg/util/mapper"
)
  13. //--------------------------------------------------------------------------
  14. // QueryParams
  15. //--------------------------------------------------------------------------
  16. // valuesPrimitiveMap implements mapper.PrimitiveMap so we can build extra
  17. // functionality into the QueryParams interface.
  18. type valuesPrimitiveMap struct {
  19. url.Values
  20. }
  21. func (values valuesPrimitiveMap) Has(key string) bool {
  22. return values.Values.Has(key)
  23. }
  24. func (values valuesPrimitiveMap) Get(key string) string {
  25. return values.Values.Get(key)
  26. }
  27. func (values valuesPrimitiveMap) Set(key, value string) error {
  28. values.Values.Set(key, value)
  29. return nil
  30. }
// QueryParams provides basic map access to URL values as well as providing
// helpful additional functionality for validation.
type QueryParams interface {
	mapper.PrimitiveMap

	// InvalidKeys returns the set of param keys which are not present in the
	// possible valid set. It is a set subtraction: present - valid = invalid.
	//
	// Example usage to catch a typo, where the request was made with the
	// misspelled key "filterClsuters":
	//
	//	qp.InvalidKeys([]string{"window", "aggregate", "filterClusters"})
	//	// -> []string{"filterClsuters"}
	//
	// If qp contains no keys, then this should always return an empty
	// slice/nil.
	InvalidKeys(possibleValidKeys []string) (invalidKeys []string)
}
// queryParamsMap implements the QueryParams interface on top of
// valuesPrimitiveMap.
type queryParamsMap struct {
	// values is the raw set of URL values; InvalidKeys ranges over its keys.
	values url.Values
	// PrimitiveMap is embedded so its methods are promoted onto
	// queryParamsMap. NewQueryParams fills it with the result of
	// mapper.NewMapper over the same URL values.
	mapper.PrimitiveMap
}
  51. // NewQueryParams creates a primitive map using the request query parameters
  52. func NewQueryParams(values url.Values) QueryParams {
  53. vpm := valuesPrimitiveMap{values}
  54. return &queryParamsMap{
  55. values: values,
  56. PrimitiveMap: mapper.NewMapper(vpm),
  57. }
  58. }
  59. // InvalidKeys performs a set difference: Params keys - possible valid keys.
  60. //
  61. // For now, dealing with cache busting parameters should be the handler's
  62. // responsibility.
  63. func (qpm *queryParamsMap) InvalidKeys(possibleValidKeys []string) []string {
  64. validMap := map[string]struct{}{}
  65. for _, validKey := range possibleValidKeys {
  66. validMap[validKey] = struct{}{}
  67. }
  68. var invalidKeys []string
  69. for key := range qpm.values {
  70. if _, ok := validMap[key]; !ok {
  71. invalidKeys = append(invalidKeys, key)
  72. }
  73. }
  74. return invalidKeys
  75. }
  76. //--------------------------------------------------------------------------
  77. // HTTP Context Utilities
  78. //--------------------------------------------------------------------------
// Keys under which the Set*/Get* helpers in this file store and look up
// values on a request's context.
//
// NOTE(review): these are plain string keys; the context package
// documentation recommends a dedicated key type to avoid collisions between
// packages. Changing the type would affect any caller using these constants
// directly, so confirm usages before tightening it.
const (
	// ContextWarning is the context key for the warning message.
	ContextWarning string = "Warning"
	// ContextName is the context key for the name value.
	ContextName string = "Name"
	// ContextQuery is the context key for the query value.
	ContextQuery string = "Query"
)
  84. // GetWarning Extracts a warning message from the request context if it exists
  85. func GetWarning(r *http.Request) (warning string, ok bool) {
  86. warning, ok = r.Context().Value(ContextWarning).(string)
  87. return
  88. }
  89. // SetWarning Sets the warning context on the provided request and returns a new instance of the request
  90. // with the new context.
  91. func SetWarning(r *http.Request, warning string) *http.Request {
  92. ctx := context.WithValue(r.Context(), ContextWarning, warning)
  93. return r.WithContext(ctx)
  94. }
  95. // GetName Extracts a name value from the request context if it exists
  96. func GetName(r *http.Request) (name string, ok bool) {
  97. name, ok = r.Context().Value(ContextName).(string)
  98. return
  99. }
  100. // SetName Sets the name value on the provided request and returns a new instance of the request
  101. // with the new context.
  102. func SetName(r *http.Request, name string) *http.Request {
  103. ctx := context.WithValue(r.Context(), ContextName, name)
  104. return r.WithContext(ctx)
  105. }
  106. // GetQuery Extracts a query value from the request context if it exists
  107. func GetQuery(r *http.Request) (name string, ok bool) {
  108. name, ok = r.Context().Value(ContextQuery).(string)
  109. return
  110. }
  111. // SetQuery Sets the query value on the provided request and returns a new instance of the request
  112. // with the new context.
  113. func SetQuery(r *http.Request, query string) *http.Request {
  114. ctx := context.WithValue(r.Context(), ContextQuery, query)
  115. return r.WithContext(ctx)
  116. }
  117. //--------------------------------------------------------------------------
  118. // Package Funcs
  119. //--------------------------------------------------------------------------
  120. // IsRateLimited accepts a response and body to determine if either indicate
  121. // a rate limited return
  122. func IsRateLimited(resp *http.Response, body []byte) bool {
  123. return IsRateLimitedResponse(resp) || IsRateLimitedBody(resp, body)
  124. }
  125. // RateLimitedRetryFor returns the parsed Retry-After header relative to the
  126. // current time. If the Retry-After header does not exist, the defaultWait parameter
  127. // is returned.
  128. func RateLimitedRetryFor(resp *http.Response, defaultWait time.Duration, retry int) time.Duration {
  129. if resp.Header == nil {
  130. return ExponentialBackoffWaitFor(defaultWait, retry)
  131. }
  132. // Retry-After is either the number of seconds to wait or a target datetime (RFC1123)
  133. value := resp.Header.Get("Retry-After")
  134. if value == "" {
  135. return defaultWait
  136. }
  137. seconds, err := strconv.ParseInt(value, 10, 64)
  138. if err == nil {
  139. return time.Duration(seconds) * time.Second
  140. }
  141. // failed to parse an integer, try datetime RFC1123
  142. t, err := time.Parse(time.RFC1123, value)
  143. if err == nil {
  144. // return 0 if the datetime has already elapsed
  145. result := t.Sub(time.Now())
  146. if result < 0 {
  147. return 0
  148. }
  149. return result
  150. }
  151. // failed to parse datetime, return default
  152. return defaultWait
  153. }
  154. // ExpontentialBackoffWatiFor accepts a default wait duration and the current retry count
  155. // and returns a new duration
  156. func ExponentialBackoffWaitFor(defaultWait time.Duration, retry int) time.Duration {
  157. return time.Duration(math.Pow(2, float64(retry))*float64(defaultWait.Milliseconds())) * time.Millisecond
  158. }
  159. // IsRateLimitedResponse returns true if the status code is a 429 (TooManyRequests)
  160. func IsRateLimitedResponse(resp *http.Response) bool {
  161. return resp.StatusCode == http.StatusTooManyRequests
  162. }
  163. // IsRateLimitedBody attempts to determine if a response body indicates throttling
  164. // has occurred. This function is a result of some API providers (AWS) returning
  165. // a 400 status code instead of 429 for rate limit exceptions.
  166. func IsRateLimitedBody(resp *http.Response, body []byte) bool {
  167. // ignore non-400 status
  168. if resp.StatusCode < http.StatusBadRequest || resp.StatusCode >= http.StatusInternalServerError {
  169. return false
  170. }
  171. return strings.Contains(string(body), "ThrottlingException")
  172. }
  173. // HeaderString writes the request/response http.Header to a string.
  174. func HeaderString(h http.Header) string {
  175. var sb strings.Builder
  176. var first bool = true
  177. sb.WriteString("{ ")
  178. for k, vs := range h {
  179. if first {
  180. first = false
  181. } else {
  182. sb.WriteString(", ")
  183. }
  184. fmt.Fprintf(&sb, "%s: [ ", k)
  185. for idx, v := range vs {
  186. sb.WriteString(v)
  187. if idx != len(vs)-1 {
  188. sb.WriteString(", ")
  189. }
  190. }
  191. sb.WriteString(" ]")
  192. }
  193. sb.WriteString(" }")
  194. return sb.String()
  195. }