httputil.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package httputil
  2. import (
  3. "context"
  4. "fmt"
  5. "math"
  6. "net/http"
  7. "net/url"
  8. "strconv"
  9. "strings"
  10. "time"
  11. "github.com/opencost/opencost/pkg/util/mapper"
  12. )
  13. //--------------------------------------------------------------------------
  14. // QueryParams
  15. //--------------------------------------------------------------------------
  16. // valuesPrimitiveMap implements mapper.PrimitiveMap so we can build extra
  17. // functionality into the QueryParams interface.
  18. type valuesPrimitiveMap struct {
  19. url.Values
  20. }
  21. func (values valuesPrimitiveMap) Has(key string) bool {
  22. return values.Values.Has(key)
  23. }
  24. func (values valuesPrimitiveMap) Get(key string) string {
  25. return values.Values.Get(key)
  26. }
  27. func (values valuesPrimitiveMap) Set(key, value string) error {
  28. values.Values.Set(key, value)
  29. return nil
  30. }
  31. // QueryParams provides basic map access to URL values as well as providing
  32. // helpful additional functionality for validation.
  33. type QueryParams interface {
  34. mapper.PrimitiveMap
  35. // InvalidKeys returns the set of param keys which are not present in the
  36. // possible valid set. It is a set subtraction: present - valid = invalid
  37. //
  38. // Example usage to catch a typo:
  39. // qp.InvalidKeys([]string{"window", "aggregate", "filterClusters"}) ->
  40. // "filterClsuters"
  41. //
  42. InvalidKeys(possibleValidKeys []string) (invalidKeys []string)
  43. }
  44. // queryParamsMap implements the QueryParams interface on top of
  45. // valuesPrimitiveMap.
  46. type queryParamsMap struct {
  47. values url.Values
  48. mapper.PrimitiveMap
  49. }
  50. // NewQueryParams creates a primitive map using the request query parameters
  51. func NewQueryParams(values url.Values) QueryParams {
  52. vpm := valuesPrimitiveMap{values}
  53. return &queryParamsMap{
  54. values: values,
  55. PrimitiveMap: mapper.NewMapper(vpm),
  56. }
  57. }
  58. // TODO: How to handle "cache buster" params?
  59. // InvalidKeys(expectedKeys []string) (invalid []string)
  60. func (qpm *queryParamsMap) InvalidKeys(possibleValidKeys []string) []string {
  61. validMap := map[string]struct{}{}
  62. for _, validKey := range possibleValidKeys {
  63. validMap[validKey] = struct{}{}
  64. }
  65. var invalidKeys []string
  66. for key := range qpm.values {
  67. if _, ok := validMap[key]; !ok {
  68. invalidKeys = append(invalidKeys, key)
  69. }
  70. }
  71. return invalidKeys
  72. }
  73. //--------------------------------------------------------------------------
  74. // HTTP Context Utilities
  75. //--------------------------------------------------------------------------
  76. const (
  77. ContextWarning string = "Warning"
  78. ContextName string = "Name"
  79. ContextQuery string = "Query"
  80. )
  81. // GetWarning Extracts a warning message from the request context if it exists
  82. func GetWarning(r *http.Request) (warning string, ok bool) {
  83. warning, ok = r.Context().Value(ContextWarning).(string)
  84. return
  85. }
  86. // SetWarning Sets the warning context on the provided request and returns a new instance of the request
  87. // with the new context.
  88. func SetWarning(r *http.Request, warning string) *http.Request {
  89. ctx := context.WithValue(r.Context(), ContextWarning, warning)
  90. return r.WithContext(ctx)
  91. }
  92. // GetName Extracts a name value from the request context if it exists
  93. func GetName(r *http.Request) (name string, ok bool) {
  94. name, ok = r.Context().Value(ContextName).(string)
  95. return
  96. }
  97. // SetName Sets the name value on the provided request and returns a new instance of the request
  98. // with the new context.
  99. func SetName(r *http.Request, name string) *http.Request {
  100. ctx := context.WithValue(r.Context(), ContextName, name)
  101. return r.WithContext(ctx)
  102. }
  103. // GetQuery Extracts a query value from the request context if it exists
  104. func GetQuery(r *http.Request) (name string, ok bool) {
  105. name, ok = r.Context().Value(ContextQuery).(string)
  106. return
  107. }
  108. // SetQuery Sets the query value on the provided request and returns a new instance of the request
  109. // with the new context.
  110. func SetQuery(r *http.Request, query string) *http.Request {
  111. ctx := context.WithValue(r.Context(), ContextQuery, query)
  112. return r.WithContext(ctx)
  113. }
  114. //--------------------------------------------------------------------------
  115. // Package Funcs
  116. //--------------------------------------------------------------------------
  117. // IsRateLimited accepts a response and body to determine if either indicate
  118. // a rate limited return
  119. func IsRateLimited(resp *http.Response, body []byte) bool {
  120. return IsRateLimitedResponse(resp) || IsRateLimitedBody(resp, body)
  121. }
  122. // RateLimitedRetryFor returns the parsed Retry-After header relative to the
  123. // current time. If the Retry-After header does not exist, the defaultWait parameter
  124. // is returned.
  125. func RateLimitedRetryFor(resp *http.Response, defaultWait time.Duration, retry int) time.Duration {
  126. if resp.Header == nil {
  127. return ExponentialBackoffWaitFor(defaultWait, retry)
  128. }
  129. // Retry-After is either the number of seconds to wait or a target datetime (RFC1123)
  130. value := resp.Header.Get("Retry-After")
  131. if value == "" {
  132. return defaultWait
  133. }
  134. seconds, err := strconv.ParseInt(value, 10, 64)
  135. if err == nil {
  136. return time.Duration(seconds) * time.Second
  137. }
  138. // failed to parse an integer, try datetime RFC1123
  139. t, err := time.Parse(time.RFC1123, value)
  140. if err == nil {
  141. // return 0 if the datetime has already elapsed
  142. result := t.Sub(time.Now())
  143. if result < 0 {
  144. return 0
  145. }
  146. return result
  147. }
  148. // failed to parse datetime, return default
  149. return defaultWait
  150. }
  151. // ExpontentialBackoffWatiFor accepts a default wait duration and the current retry count
  152. // and returns a new duration
  153. func ExponentialBackoffWaitFor(defaultWait time.Duration, retry int) time.Duration {
  154. return time.Duration(math.Pow(2, float64(retry))*float64(defaultWait.Milliseconds())) * time.Millisecond
  155. }
  156. // IsRateLimitedResponse returns true if the status code is a 429 (TooManyRequests)
  157. func IsRateLimitedResponse(resp *http.Response) bool {
  158. return resp.StatusCode == http.StatusTooManyRequests
  159. }
  160. // IsRateLimitedBody attempts to determine if a response body indicates throttling
  161. // has occurred. This function is a result of some API providers (AWS) returning
  162. // a 400 status code instead of 429 for rate limit exceptions.
  163. func IsRateLimitedBody(resp *http.Response, body []byte) bool {
  164. // ignore non-400 status
  165. if resp.StatusCode < http.StatusBadRequest || resp.StatusCode >= http.StatusInternalServerError {
  166. return false
  167. }
  168. return strings.Contains(string(body), "ThrottlingException")
  169. }
  170. // HeaderString writes the request/response http.Header to a string.
  171. func HeaderString(h http.Header) string {
  172. var sb strings.Builder
  173. var first bool = true
  174. sb.WriteString("{ ")
  175. for k, vs := range h {
  176. if first {
  177. first = false
  178. } else {
  179. sb.WriteString(", ")
  180. }
  181. fmt.Fprintf(&sb, "%s: [ ", k)
  182. for idx, v := range vs {
  183. sb.WriteString(v)
  184. if idx != len(vs)-1 {
  185. sb.WriteString(", ")
  186. }
  187. }
  188. sb.WriteString(" ]")
  189. }
  190. sb.WriteString(" }")
  191. return sb.String()
  192. }