| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- package httputil
- import (
- "context"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- )
- func TestBoundedClientHasTotalTimeout(t *testing.T) {
- c := BoundedClient()
- if c.Timeout != PricingTimeout {
- t.Fatalf("expected total timeout %v, got %v", PricingTimeout, c.Timeout)
- }
- }
- func TestBoundedClientIsShared(t *testing.T) {
- if BoundedClient() != BoundedClient() {
- t.Fatal("expected BoundedClient to return a shared instance")
- }
- }
- func TestStreamingClientHasNoTotalTimeout(t *testing.T) {
- c := StreamingClient()
- if c.Timeout != 0 {
- t.Fatalf("streaming client must not set a total timeout, got %v", c.Timeout)
- }
- }
- func TestStreamingClientHasResponseHeaderTimeout(t *testing.T) {
- c := StreamingClient()
- tr, ok := c.Transport.(*http.Transport)
- if !ok {
- t.Fatalf("expected *http.Transport, got %T", c.Transport)
- }
- if tr.ResponseHeaderTimeout != PricingTimeout {
- t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
- }
- }
- func TestStreamingClientIsShared(t *testing.T) {
- if StreamingClient() != StreamingClient() {
- t.Fatal("expected StreamingClient to return a shared instance")
- }
- }
// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface, analogous to how http.HandlerFunc adapts handlers.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time proof that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by delegating straight to fn.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
- // A base transport that is not a *http.Transport must not panic and must still
- // yield a usable client with the response-header timeout applied.
- func TestNewStreamingClientFallsBackWhenNotTransport(t *testing.T) {
- // Honor the RoundTripper contract (non-nil response when error is nil), even
- // though this base is only used to exercise the fallback and never round-trips.
- base := roundTripperFunc(func(*http.Request) (*http.Response, error) {
- return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
- })
- c := newStreamingClient(base)
- tr, ok := c.Transport.(*http.Transport)
- if !ok {
- t.Fatalf("expected fallback *http.Transport, got %T", c.Transport)
- }
- if tr.ResponseHeaderTimeout != PricingTimeout {
- t.Fatalf("expected response-header timeout %v, got %v", PricingTimeout, tr.ResponseHeaderTimeout)
- }
- // The fallback transport must still bound TLS handshake time.
- if tr.TLSHandshakeTimeout == 0 {
- t.Fatal("expected fallback transport to set a TLS handshake timeout")
- }
- }
- func TestStreamingGetReturnsBody(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- _, _ = w.Write([]byte("price-data"))
- }))
- defer srv.Close()
- resp, err := StreamingGet(context.Background(), srv.URL)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("reading body: %v", err)
- }
- if string(body) != "price-data" {
- t.Fatalf("expected body %q, got %q", "price-data", string(body))
- }
- }
- func TestStreamingGetHonorsCanceledContext(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- _, _ = w.Write([]byte("ok"))
- }))
- defer srv.Close()
- ctx, cancel := context.WithCancel(context.Background())
- cancel() // cancel before the request runs
- if _, err := StreamingGet(ctx, srv.URL); err == nil {
- t.Fatal("expected an error from a canceled context, got nil")
- }
- }
- func TestStreamingGetRejectsBadURL(t *testing.T) {
- if _, err := StreamingGet(context.Background(), "://not-a-url"); err == nil {
- t.Fatal("expected an error building the request, got nil")
- }
- }
- // Cloning must not mutate the shared default transport.
- func TestNewStreamingClientClonesWithoutMutatingDefault(t *testing.T) {
- def, ok := http.DefaultTransport.(*http.Transport)
- if !ok {
- t.Skip("default transport is not *http.Transport in this environment")
- }
- c := newStreamingClient(def)
- tr := c.Transport.(*http.Transport)
- if tr == def {
- t.Fatal("expected a cloned transport, got the shared default")
- }
- if def.ResponseHeaderTimeout == PricingTimeout {
- t.Fatal("mutated the shared default transport's ResponseHeaderTimeout")
- }
- }
|