|
|
@@ -0,0 +1,130 @@
|
|
|
+package httputil
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "net/http/httptest"
|
|
|
+ "testing"
|
|
|
+)
|
|
|
+
|
|
|
+func TestBoundedClientHasTotalTimeout(t *testing.T) {
|
|
|
+ c := BoundedClient()
|
|
|
+ if c.Timeout != PricingTimeout {
|
|
|
+ t.Fatalf("expected total timeout %v, got %v", PricingTimeout, c.Timeout)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestBoundedClientIsShared(t *testing.T) {
|
|
|
+ if BoundedClient() != BoundedClient() {
|
|
|
+ t.Fatal("expected BoundedClient to return a shared instance")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingClientHasNoTotalTimeout(t *testing.T) {
|
|
|
+ c := StreamingClient()
|
|
|
+ if c.Timeout != 0 {
|
|
|
+ t.Fatalf("streaming client must not set a total timeout, got %v", c.Timeout)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingClientHasResponseHeaderTimeout(t *testing.T) {
|
|
|
+ c := StreamingClient()
|
|
|
+ tr, ok := c.Transport.(*http.Transport)
|
|
|
+ if !ok {
|
|
|
+ t.Fatalf("expected *http.Transport, got %T", c.Transport)
|
|
|
+ }
|
|
|
+ if tr.ResponseHeaderTimeout != PricingTimeout {
|
|
|
+ t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingClientIsShared(t *testing.T) {
|
|
|
+ if StreamingClient() != StreamingClient() {
|
|
|
+ t.Fatal("expected StreamingClient to return a shared instance")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type roundTripperFunc func(*http.Request) (*http.Response, error)
|
|
|
+
|
|
|
+func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
|
|
+ return f(r)
|
|
|
+}
|
|
|
+
|
|
|
+// A base transport that is not a *http.Transport must not panic and must still
|
|
|
+// yield a usable client with the response-header timeout applied.
|
|
|
+func TestNewStreamingClientFallsBackWhenNotTransport(t *testing.T) {
|
|
|
+ // Honor the RoundTripper contract (non-nil response when error is nil), even
|
|
|
+ // though this base is only used to exercise the fallback and never round-trips.
|
|
|
+ base := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
|
|
+ return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
|
|
+ })
|
|
|
+ c := newStreamingClient(base)
|
|
|
+ tr, ok := c.Transport.(*http.Transport)
|
|
|
+ if !ok {
|
|
|
+ t.Fatalf("expected fallback *http.Transport, got %T", c.Transport)
|
|
|
+ }
|
|
|
+ if tr.ResponseHeaderTimeout != PricingTimeout {
|
|
|
+ t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
|
|
|
+ }
|
|
|
+ // The fallback transport must still bound TLS handshake time.
|
|
|
+ if tr.TLSHandshakeTimeout == 0 {
|
|
|
+ t.Fatal("expected fallback transport to set a TLS handshake timeout")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingGetReturnsBody(t *testing.T) {
|
|
|
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
|
+ _, _ = w.Write([]byte("price-data"))
|
|
|
+ }))
|
|
|
+ defer srv.Close()
|
|
|
+
|
|
|
+ resp, err := StreamingGet(context.Background(), srv.URL)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("unexpected error: %v", err)
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+ body, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("reading body: %v", err)
|
|
|
+ }
|
|
|
+ if string(body) != "price-data" {
|
|
|
+ t.Fatalf("expected body %q, got %q", "price-data", string(body))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingGetHonorsCanceledContext(t *testing.T) {
|
|
|
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
|
+ _, _ = w.Write([]byte("ok"))
|
|
|
+ }))
|
|
|
+ defer srv.Close()
|
|
|
+
|
|
|
+ ctx, cancel := context.WithCancel(context.Background())
|
|
|
+ cancel() // cancel before the request runs
|
|
|
+
|
|
|
+ if _, err := StreamingGet(ctx, srv.URL); err == nil {
|
|
|
+ t.Fatal("expected an error from a canceled context, got nil")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamingGetRejectsBadURL(t *testing.T) {
|
|
|
+ if _, err := StreamingGet(context.Background(), "://not-a-url"); err == nil {
|
|
|
+ t.Fatal("expected an error building the request, got nil")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Cloning must not mutate the shared default transport.
|
|
|
+func TestNewStreamingClientClonesWithoutMutatingDefault(t *testing.T) {
|
|
|
+ def, ok := http.DefaultTransport.(*http.Transport)
|
|
|
+ if !ok {
|
|
|
+ t.Skip("default transport is not *http.Transport in this environment")
|
|
|
+ }
|
|
|
+ c := newStreamingClient(def)
|
|
|
+ tr := c.Transport.(*http.Transport)
|
|
|
+ if tr == def {
|
|
|
+ t.Fatal("expected a cloned transport, got the shared default")
|
|
|
+ }
|
|
|
+ if def.ResponseHeaderTimeout == PricingTimeout {
|
|
|
+ t.Fatal("mutated the shared default transport's ResponseHeaderTimeout")
|
|
|
+ }
|
|
|
+}
|