2
0

roundtrip_test.go 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. package httputil
  2. import (
  3. "fmt"
  4. "net/http"
  5. "reflect"
  6. "testing"
  7. )
  8. type reqValidateRoundTripper struct {
  9. expectedReq *http.Request
  10. }
  11. func (rt *reqValidateRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
  12. if !reflect.DeepEqual(r, rt.expectedReq) {
  13. return nil, fmt.Errorf("expected req %v, got %v", rt.expectedReq, r)
  14. }
  15. return nil, nil
  16. }
  17. func TestUserAgentTransport(t *testing.T) {
  18. for _, tc := range []struct {
  19. name string
  20. ua string
  21. req *http.Request
  22. expReq *http.Request
  23. }{
  24. {
  25. name: "opencost",
  26. ua: "opencost",
  27. req: &http.Request{},
  28. expReq: &http.Request{Header: http.Header{"User-Agent": []string{"opencost"}}},
  29. },
  30. {
  31. name: "foo",
  32. ua: "foo",
  33. req: &http.Request{},
  34. expReq: &http.Request{Header: http.Header{"User-Agent": []string{"foo"}}},
  35. },
  36. {
  37. name: "overwrite user agent if exists",
  38. ua: "opencost",
  39. req: &http.Request{Header: http.Header{"User-Agent": []string{"foo"}}},
  40. expReq: &http.Request{Header: http.Header{"User-Agent": []string{"opencost"}}},
  41. },
  42. } {
  43. t.Run(tc.name, func(t *testing.T) {
  44. rt := NewUserAgentTransport(tc.ua, &reqValidateRoundTripper{
  45. expectedReq: tc.expReq,
  46. })
  47. _, err := rt.RoundTrip(tc.req)
  48. if err != nil {
  49. t.Error(err)
  50. }
  51. })
  52. }
  53. }