Просмотр исходного кода

Merge pull request #1273 from porter-dev/belanger/por-77-api-500-level-fixes

[POR-77] Websocket logic fixes and removing more 500-level errors
abelanger5 4 лет назад
Родитель
Сommit
367b9ef255
34 измененных файлов с 486 добавлено и 207 удалено
  1. 1 1
      api/server/authn/handler.go
  2. 2 2
      api/server/authz/cluster.go
  3. 2 2
      api/server/authz/git_installation.go
  4. 2 2
      api/server/authz/helm_repo.go
  5. 2 2
      api/server/authz/infra.go
  6. 2 2
      api/server/authz/invite.go
  7. 3 2
      api/server/authz/policy.go
  8. 11 1
      api/server/authz/project.go
  9. 2 2
      api/server/authz/registry.go
  10. 3 3
      api/server/authz/release.go
  11. 3 8
      api/server/handlers/cluster/stream_helm_release.go
  12. 3 8
      api/server/handlers/cluster/stream_status.go
  13. 6 1
      api/server/handlers/handler.go
  14. 3 8
      api/server/handlers/infra/stream_logs.go
  15. 10 8
      api/server/handlers/namespace/stream_pod_logs.go
  16. 2 0
      api/server/router/cluster.go
  17. 1 0
      api/server/router/infra.go
  18. 11 0
      api/server/router/middleware/content_type_json.go
  19. 31 0
      api/server/router/middleware/panic.go
  20. 59 0
      api/server/router/middleware/request_logger.go
  21. 45 0
      api/server/router/middleware/websocket.go
  22. 1 0
      api/server/router/namespace.go
  23. 11 80
      api/server/router/router.go
  24. 16 13
      api/server/shared/apierrors/errors.go
  25. 1 1
      api/server/shared/apitest/request.go
  26. 1 1
      api/server/shared/config/config.go
  27. 9 6
      api/server/shared/config/loader/loader.go
  28. 2 2
      api/server/shared/reader.go
  29. 78 0
      api/server/shared/websocket/response_writer.go
  30. 34 0
      api/server/shared/websocket/upgrader.go
  31. 3 3
      api/server/shared/writer.go
  32. 5 0
      api/types/request.go
  33. 117 39
      internal/kubernetes/agent.go
  34. 4 10
      internal/kubernetes/provisioner/resource_stream.go

+ 1 - 1
api/server/authn/handler.go

@@ -153,7 +153,7 @@ func (authn *AuthN) nextWithUserID(w http.ResponseWriter, r *http.Request, userI
 func (authn *AuthN) sendForbiddenError(err error, w http.ResponseWriter, r *http.Request) {
 	reqErr := apierrors.NewErrForbidden(err)
 
-	apierrors.HandleAPIError(authn.config, w, r, reqErr)
+	apierrors.HandleAPIError(authn.config, w, r, reqErr, true)
 }
 
 var errInvalidToken = fmt.Errorf("authorization header exists, but token is not valid")

+ 2 - 2
api/server/authz/cluster.go

@@ -52,9 +52,9 @@ func (p *ClusterScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("cluster with id %d not found in project %d", clusterID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/git_installation.go

@@ -46,12 +46,12 @@ func (p *GitInstallationScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *ht
 	gitInstallation, err := p.config.Repo.GithubAppInstallation().ReadGithubAppInstallationByInstallationID(gitInstallationID)
 
 	if err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 
 	if err := p.doesUserHaveGitInstallationAccess(user.GithubAppIntegrationID, gitInstallationID); err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 

+ 2 - 2
api/server/authz/helm_repo.go

@@ -45,9 +45,9 @@ func (p *HelmRepoScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("helm repo with id %d not found in project %d", helmRepoID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/infra.go

@@ -45,9 +45,9 @@ func (p *InfraScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("infra with id %d not found in project %d", infraID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/invite.go

@@ -45,9 +45,9 @@ func (p *InviteScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("invite with id %d not found in project %d", inviteID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 3 - 2
api/server/authz/policy.go

@@ -43,7 +43,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	reqScopes, reqErr := getRequestActionForEndpoint(r, h.endpointMeta)
 
 	if reqErr != nil {
-		apierrors.HandleAPIError(h.config, w, r, reqErr)
+		apierrors.HandleAPIError(h.config, w, r, reqErr, true)
 		return
 	}
 
@@ -54,7 +54,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	policyDocs, reqErr := h.loader.LoadPolicyDocuments(user.ID, projID)
 
 	if reqErr != nil {
-		apierrors.HandleAPIError(h.config, w, r, reqErr)
+		apierrors.HandleAPIError(h.config, w, r, reqErr, true)
 		return
 	}
 
@@ -67,6 +67,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			w,
 			r,
 			apierrors.NewErrForbidden(fmt.Errorf("policy forbids action for user %d in project %d", user.ID, projID)),
+			true,
 		)
 
 		return

+ 11 - 1
api/server/authz/project.go

@@ -2,12 +2,14 @@ package authz
 
 import (
 	"context"
+	"fmt"
 	"net/http"
 
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
+	"gorm.io/gorm"
 )
 
 type ProjectScopedFactory struct {
@@ -38,7 +40,15 @@ func (p *ProjectScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 	project, err := p.config.Repo.Project().ReadProject(projID)
 
 	if err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		if err == gorm.ErrRecordNotFound {
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
+				fmt.Errorf("project not found with id %d", projID),
+			), true)
+
+			return
+		}
+
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 

+ 2 - 2
api/server/authz/registry.go

@@ -45,9 +45,9 @@ func (p *RegistryScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("registry with id %d not found in project %d", registryID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 3 - 3
api/server/authz/release.go

@@ -40,7 +40,7 @@ func (p *ReleaseScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 	helmAgent, err := p.agentGetter.GetHelmAgent(r, cluster, "")
 
 	if err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 
@@ -60,9 +60,9 @@ func (p *ReleaseScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrPassThroughToClient(
 				fmt.Errorf("release not found"),
 				http.StatusNotFound,
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 3 - 8
api/server/handlers/cluster/stream_helm_release.go

@@ -8,6 +8,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 )
@@ -29,13 +30,7 @@ func NewStreamHelmReleaseHandler(
 }
 
 func (c *StreamHelmReleaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	request := &types.StreamHelmReleaseRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {
@@ -51,7 +46,7 @@ func (c *StreamHelmReleaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		return
 	}
 
-	err = agent.StreamHelmReleases(conn, request.Namespace, request.Charts, request.Selectors)
+	err = agent.StreamHelmReleases(request.Namespace, request.Charts, request.Selectors, safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 3 - 8
api/server/handlers/cluster/stream_status.go

@@ -9,6 +9,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/requestutils"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 )
@@ -30,13 +31,7 @@ func NewStreamStatusHandler(
 }
 
 func (c *StreamStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	request := &types.StreamStatusRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {
@@ -54,7 +49,7 @@ func (c *StreamStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 
 	kind, _ := requestutils.GetURLParamString(r, types.URLParamKind)
 
-	err = agent.StreamControllerStatus(conn, kind, request.Selectors)
+	err = agent.StreamControllerStatus(kind, request.Selectors, safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 6 - 1
api/server/handlers/handler.go

@@ -16,6 +16,7 @@ type PorterHandler interface {
 	Config() *config.Config
 	Repo() repository.Repository
 	HandleAPIError(w http.ResponseWriter, r *http.Request, err apierrors.RequestError)
+	HandleAPIErrorNoWrite(w http.ResponseWriter, r *http.Request, err apierrors.RequestError)
 	PopulateOAuthSession(w http.ResponseWriter, r *http.Request, state string, isProject bool) error
 }
 
@@ -57,7 +58,11 @@ func (d *DefaultPorterHandler) Repo() repository.Repository {
 }
 
 func (d *DefaultPorterHandler) HandleAPIError(w http.ResponseWriter, r *http.Request, err apierrors.RequestError) {
-	apierrors.HandleAPIError(d.Config(), w, r, err)
+	apierrors.HandleAPIError(d.Config(), w, r, err, true)
+}
+
+func (d *DefaultPorterHandler) HandleAPIErrorNoWrite(w http.ResponseWriter, r *http.Request, err apierrors.RequestError) {
+	apierrors.HandleAPIError(d.Config(), w, r, err, false)
 }
 
 func (d *DefaultPorterHandler) WriteResult(w http.ResponseWriter, r *http.Request, v interface{}) {

+ 3 - 8
api/server/handlers/infra/stream_logs.go

@@ -7,6 +7,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/adapter"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner"
@@ -27,13 +28,7 @@ func NewInfraStreamLogsHandler(
 }
 
 func (c *InfraStreamLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	infra, _ := r.Context().Value(types.InfraScope).(*models.Infra)
 
 	client, err := adapter.NewRedisClient(c.Config().RedisConf)
@@ -43,7 +38,7 @@ func (c *InfraStreamLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
 		return
 	}
 
-	err = provisioner.ResourceStream(client, infra.GetUniqueName(), conn)
+	err = provisioner.ResourceStream(client, infra.GetUniqueName(), safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 10 - 8
api/server/handlers/namespace/stream_pod_logs.go

@@ -11,6 +11,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/requestutils"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/kubernetes"
 	"github.com/porter-dev/porter/internal/models"
@@ -33,13 +34,7 @@ func NewStreamPodLogsHandler(
 }
 
 func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	namespace := r.Context().Value(types.NamespaceScope).(string)
 	name, _ := requestutils.GetURLParamString(r, types.URLParamPodName)
 
@@ -52,7 +47,7 @@ func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	err = agent.GetPodLogs(namespace, name, conn)
+	err = agent.GetPodLogs(namespace, name, safeRW)
 
 	if targetErr := kubernetes.IsNotFoundError; errors.Is(err, targetErr) {
 		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(
@@ -60,6 +55,13 @@ func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 			http.StatusNotFound,
 		))
 
+		return
+	} else if brErr := (kubernetes.BadRequestError{}); errors.As(err, &targetErr) {
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(
+			&brErr,
+			http.StatusBadRequest,
+		))
+
 		return
 	} else if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 2 - 0
api/server/router/cluster.go

@@ -518,6 +518,7 @@ func getClusterRoutes(
 				types.ProjectScope,
 				types.ClusterScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 
@@ -551,6 +552,7 @@ func getClusterRoutes(
 				types.ProjectScope,
 				types.ClusterScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 1 - 0
api/server/router/infra.go

@@ -121,6 +121,7 @@ func getInfraRoutes(
 				types.ProjectScope,
 				types.InfraScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 11 - 0
api/server/router/middleware/content_type_json.go

@@ -0,0 +1,11 @@
+package middleware
+
+import "net/http"
+
+// ContentTypeJSON sets the content type for requests to application/json
+func ContentTypeJSON(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json;charset=utf8")
+		next.ServeHTTP(w, r)
+	})
+}

+ 31 - 0
api/server/router/middleware/panic.go

@@ -0,0 +1,31 @@
+package middleware
+
+import (
+	"fmt"
+	"net/http"
+
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+)
+
+type PanicMiddleware struct {
+	config *config.Config
+}
+
+func NewPanicMiddleware(config *config.Config) *PanicMiddleware {
+	return &PanicMiddleware{config}
+}
+
+func (pmw *PanicMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		defer func() {
+			err := recover()
+
+			if err != nil {
+				apierrors.HandleAPIError(pmw.config, w, r, apierrors.NewErrInternal(fmt.Errorf("%v", err)), true)
+			}
+		}()
+
+		next.ServeHTTP(w, r)
+	})
+}

+ 59 - 0
api/server/router/middleware/request_logger.go

@@ -0,0 +1,59 @@
+package middleware
+
+import (
+	"bufio"
+	"errors"
+	"net"
+	"net/http"
+	"time"
+
+	"github.com/porter-dev/porter/internal/logger"
+)
+
+type requestLoggerResponseWriter struct {
+	http.ResponseWriter
+	statusCode int
+}
+
+func newRequestLoggerResponseWriter(w http.ResponseWriter) *requestLoggerResponseWriter {
+	return &requestLoggerResponseWriter{w, http.StatusOK}
+}
+
+func (rw *requestLoggerResponseWriter) WriteHeader(code int) {
+	rw.statusCode = code
+	rw.ResponseWriter.WriteHeader(code)
+}
+
+func (rw *requestLoggerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	h, ok := rw.ResponseWriter.(http.Hijacker)
+	if !ok {
+		return nil, nil, errors.New("ResponseWriter Interface does not support hijacking")
+	}
+	return h.Hijack()
+}
+
+type RequestLoggerMiddleware struct {
+	logger *logger.Logger
+}
+
+func NewRequestLoggerMiddleware(logger *logger.Logger) *RequestLoggerMiddleware {
+	return &RequestLoggerMiddleware{logger}
+}
+
+func (mw *RequestLoggerMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		start := time.Now()
+		rw := newRequestLoggerResponseWriter(w)
+
+		next.ServeHTTP(rw, r)
+
+		latency := time.Since(start)
+
+		event := mw.logger.Info().Dur("latency", latency).Int("status", rw.statusCode)
+
+		logger.AddLoggingContextScopes(r.Context(), event)
+		logger.AddLoggingRequestMeta(r, event)
+
+		event.Send()
+	})
+}

+ 45 - 0
api/server/router/middleware/websocket.go

@@ -0,0 +1,45 @@
+package middleware
+
+import (
+	"context"
+	"errors"
+	"net/http"
+
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
+	"github.com/porter-dev/porter/api/types"
+)
+
+type WebsocketMiddleware struct {
+	config *config.Config
+}
+
+func NewWebsocketMiddleware(config *config.Config) *WebsocketMiddleware {
+	return &WebsocketMiddleware{config}
+}
+
+func (wm *WebsocketMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, newRW, safeRW, err := wm.config.WSUpgrader.Upgrade(w, r, nil)
+
+		if err != nil {
+			if errors.Is(err, websocket.UpgraderCheckOriginErr) {
+				apierrors.HandleAPIError(wm.config, w, r, apierrors.NewErrForbidden(err), true)
+				return
+			} else {
+				apierrors.HandleAPIError(wm.config, w, r, apierrors.NewErrInternal(err), false)
+				return
+			}
+		}
+
+		w = newRW
+		defer conn.Close()
+
+		ctx := r.Context()
+		ctx = context.WithValue(ctx, types.RequestCtxWebsocketKey, safeRW)
+
+		r = r.Clone(ctx)
+		next.ServeHTTP(w, r)
+	})
+}

+ 1 - 0
api/server/router/namespace.go

@@ -283,6 +283,7 @@ func getNamespaceRoutes(
 				types.ClusterScope,
 				types.NamespaceScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 11 - 80
api/server/router/router.go

@@ -1,25 +1,19 @@
 package router
 
 import (
-	"bufio"
-	"errors"
-	"fmt"
-	"net"
 	"net/http"
 	"os"
 	"path"
 	"strings"
-	"time"
 
 	"github.com/go-chi/chi"
 	"github.com/porter-dev/porter/api/server/authn"
 	"github.com/porter-dev/porter/api/server/authz"
 	"github.com/porter-dev/porter/api/server/authz/policy"
+	"github.com/porter-dev/porter/api/server/router/middleware"
 	"github.com/porter-dev/porter/api/server/shared"
-	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/types"
-	"github.com/porter-dev/porter/internal/logger"
 )
 
 func NewAPIRouter(config *config.Config) *chi.Mux {
@@ -54,14 +48,14 @@ func NewAPIRouter(config *config.Config) *chi.Mux {
 	)
 
 	userRegisterer := NewUserScopedRegisterer(projRegisterer)
-	panicMW := &PanicMiddleware{config}
+	panicMW := middleware.NewPanicMiddleware(config)
 
 	r.Route("/api", func(r chi.Router) {
 		// set panic middleware for all API endpoints to catch panics
 		r.Use(panicMW.Middleware)
 
 		// set the content type for all API endpoints and log all request info
-		r.Use(ContentTypeJSON)
+		r.Use(middleware.ContentTypeJSON)
 
 		baseRoutes := baseRegisterer.GetRoutes(
 			r,
@@ -190,7 +184,10 @@ func registerRoutes(config *config.Config, routes []*Route) {
 	policyDocLoader := policy.NewBasicPolicyDocumentLoader(config.Repo.Project())
 
 	// set up logging middleware to log information about the request
-	loggerMw := &RequestLoggerMiddleware{config.Logger}
+	loggerMw := middleware.NewRequestLoggerMiddleware(config.Logger)
+
+	// websocket middleware for upgrading requests
+	websocketMw := middleware.NewWebsocketMiddleware(config)
 
 	for _, route := range routes {
 		atomicGroup := route.Router.Group(nil)
@@ -232,6 +229,10 @@ func registerRoutes(config *config.Config, routes []*Route) {
 			atomicGroup.Use(loggerMw.Middleware)
 		}
 
+		if route.Endpoint.Metadata.IsWebsocket {
+			atomicGroup.Use(websocketMw.Middleware)
+		}
+
 		atomicGroup.Method(
 			string(route.Endpoint.Metadata.Method),
 			route.Endpoint.Metadata.Path.RelativePath,
@@ -239,73 +240,3 @@ func registerRoutes(config *config.Config, routes []*Route) {
 		)
 	}
 }
-
-// ContentTypeJSON sets the content type for requests to application/json
-func ContentTypeJSON(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "application/json;charset=utf8")
-		next.ServeHTTP(w, r)
-	})
-}
-
-type requestLoggerResponseWriter struct {
-	http.ResponseWriter
-	statusCode int
-}
-
-func newRequestLoggerResponseWriter(w http.ResponseWriter) *requestLoggerResponseWriter {
-	return &requestLoggerResponseWriter{w, http.StatusOK}
-}
-
-func (rw *requestLoggerResponseWriter) WriteHeader(code int) {
-	rw.statusCode = code
-	rw.ResponseWriter.WriteHeader(code)
-}
-
-func (rw *requestLoggerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	h, ok := rw.ResponseWriter.(http.Hijacker)
-	if !ok {
-		return nil, nil, errors.New("ResponseWriter Interface does not support hijacking")
-	}
-	return h.Hijack()
-}
-
-type RequestLoggerMiddleware struct {
-	logger *logger.Logger
-}
-
-func (mw *RequestLoggerMiddleware) Middleware(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		start := time.Now()
-		rw := newRequestLoggerResponseWriter(w)
-
-		next.ServeHTTP(rw, r)
-
-		latency := time.Since(start)
-
-		event := mw.logger.Info().Dur("latency", latency).Int("status", rw.statusCode)
-
-		logger.AddLoggingContextScopes(r.Context(), event)
-		logger.AddLoggingRequestMeta(r, event)
-
-		event.Send()
-	})
-}
-
-type PanicMiddleware struct {
-	config *config.Config
-}
-
-func (pmw *PanicMiddleware) Middleware(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		defer func() {
-			err := recover()
-
-			if err != nil {
-				apierrors.HandleAPIError(pmw.config, w, r, apierrors.NewErrInternal(fmt.Errorf("%v", err)))
-			}
-		}()
-
-		next.ServeHTTP(w, r)
-	})
-}

+ 16 - 13
api/server/shared/apierrors/errors.go

@@ -95,6 +95,7 @@ func HandleAPIError(
 	w http.ResponseWriter,
 	r *http.Request,
 	err RequestError,
+	writeErr bool,
 ) {
 	extErrorStr := err.ExternalError()
 
@@ -116,24 +117,26 @@ func HandleAPIError(
 		config.Alerter.SendAlert(r.Context(), err, data)
 	}
 
-	// send the external error
-	resp := &types.ExternalError{
-		Error: extErrorStr,
-	}
+	if writeErr {
+		// send the external error
+		resp := &types.ExternalError{
+			Error: extErrorStr,
+		}
 
-	// write the status code
-	w.WriteHeader(err.GetStatusCode())
+		// write the status code
+		w.WriteHeader(err.GetStatusCode())
 
-	writerErr := json.NewEncoder(w).Encode(resp)
+		writerErr := json.NewEncoder(w).Encode(resp)
 
-	if writerErr != nil {
-		event := config.Logger.Error().
-			Err(writerErr)
+		if writerErr != nil {
+			event := config.Logger.Error().
+				Err(writerErr)
 
-		logger.AddLoggingContextScopes(r.Context(), event)
-		logger.AddLoggingRequestMeta(r, event)
+			logger.AddLoggingContextScopes(r.Context(), event)
+			logger.AddLoggingRequestMeta(r, event)
 
-		event.Send()
+			event.Send()
+		}
 	}
 
 	return

+ 1 - 1
api/server/shared/apitest/request.go

@@ -65,7 +65,7 @@ func (f *failingDecoderValidator) DecodeAndValidate(
 	r *http.Request,
 	v interface{},
 ) (ok bool) {
-	apierrors.HandleAPIError(f.config, w, r, apierrors.NewErrInternal(fmt.Errorf("fake error")))
+	apierrors.HandleAPIError(f.config, w, r, apierrors.NewErrInternal(fmt.Errorf("fake error")), true)
 	return false
 }
 

+ 1 - 1
api/server/shared/config/config.go

@@ -2,9 +2,9 @@ package config
 
 import (
 	"github.com/gorilla/sessions"
-	"github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/api/server/shared/apierrors/alerter"
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/analytics"
 	"github.com/porter-dev/porter/internal/auth/token"
 	"github.com/porter-dev/porter/internal/helm/urlcache"

+ 9 - 6
api/server/shared/config/loader/loader.go

@@ -5,10 +5,11 @@ import (
 	"net/http"
 	"strconv"
 
-	"github.com/gorilla/websocket"
+	gorillaws "github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/api/server/shared/apierrors/alerter"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/adapter"
 	"github.com/porter-dev/porter/internal/analytics"
 	"github.com/porter-dev/porter/internal/auth/sessionstore"
@@ -165,11 +166,13 @@ func (e *EnvConfigLoader) LoadConfig() (res *config.Config, err error) {
 	}
 
 	res.WSUpgrader = &websocket.Upgrader{
-		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			origin := r.Header.Get("Origin")
-			return origin == sc.ServerURL
+		WSUpgrader: &gorillaws.Upgrader{
+			ReadBufferSize:  1024,
+			WriteBufferSize: 1024,
+			CheckOrigin: func(r *http.Request) bool {
+				origin := r.Header.Get("Origin")
+				return origin == sc.ServerURL
+			},
 		},
 	}
 

+ 2 - 2
api/server/shared/reader.go

@@ -38,13 +38,13 @@ func (j *DefaultRequestDecoderValidator) DecodeAndValidate(
 
 	// decode the request parameters (body and query)
 	if requestErr = j.decoder.Decode(v, r); requestErr != nil {
-		apierrors.HandleAPIError(j.config, w, r, requestErr)
+		apierrors.HandleAPIError(j.config, w, r, requestErr, true)
 		return false
 	}
 
 	// validate the request object
 	if requestErr = j.validator.Validate(v); requestErr != nil {
-		apierrors.HandleAPIError(j.config, w, r, requestErr)
+		apierrors.HandleAPIError(j.config, w, r, requestErr, true)
 		return false
 	}
 

+ 78 - 0
api/server/shared/websocket/response_writer.go

@@ -0,0 +1,78 @@
+package websocket
+
+import (
+	"errors"
+	"net/http"
+	"syscall"
+
+	"github.com/gorilla/websocket"
+)
+
+type WebsocketSafeReadWriter struct {
+	conn *websocket.Conn
+}
+
+func (w *WebsocketSafeReadWriter) WriteJSONWithChannel(v interface{}, errorChan chan<- error) {
+	err := w.conn.WriteJSON(v)
+
+	if err != nil {
+		if errOr(err, websocket.ErrCloseSent, syscall.EPIPE, syscall.ECONNRESET) {
+			// if close has been sent, or error is broken pipe error or connection reset, we want to
+			// send a message to the error channel to ensure closure but we ignore the error
+			errorChan <- nil
+		} else if err != nil {
+			errorChan <- err
+		}
+	}
+}
+
+func (w *WebsocketSafeReadWriter) Write(data []byte) (int, error) {
+	err := w.conn.WriteMessage(websocket.TextMessage, data)
+
+	if err != nil {
+		if errOr(err, websocket.ErrCloseSent, syscall.EPIPE, syscall.ECONNRESET) {
+			// if close has been sent, or error is broken pipe error or connection reset, we want to
+			// send a message to the error channel to ensure closure but we ignore the error
+			return 0, nil
+		} else if err != nil {
+			return 0, err
+		}
+	}
+
+	return len(data), nil
+}
+
+func (w *WebsocketSafeReadWriter) ReadMessage() (messageType int, p []byte, err error) {
+	return w.conn.ReadMessage()
+}
+
+type WebsocketResponseWriter struct {
+	conn       *websocket.Conn
+	safeWriter *WebsocketSafeReadWriter
+}
+
+// no HTTP headers in websocket protocol
+func (w *WebsocketResponseWriter) Header() http.Header {
+	return nil
+}
+
+// Write attempts to write a message to the websocket connection
+func (w *WebsocketResponseWriter) Write(data []byte) (int, error) {
+	return w.safeWriter.Write(data)
+}
+
+// no-op; no HTTP headers in websocket protocol
+func (w *WebsocketResponseWriter) WriteHeader(statusCode int) {
+	return
+}
+
+// helper that returns true when `err` matches any of the candidates
+func errOr(err error, candidates ...error) bool {
+	res := false
+
+	for _, cErr := range candidates {
+		res = res || errors.Is(err, cErr)
+	}
+
+	return res
+}

+ 34 - 0
api/server/shared/websocket/upgrader.go

@@ -0,0 +1,34 @@
+package websocket
+
+import (
+	"fmt"
+	"net/http"
+
+	"github.com/gorilla/websocket"
+)
+
+type Upgrader struct {
+	WSUpgrader *websocket.Upgrader
+}
+
+var UpgraderCheckOriginErr = fmt.Errorf("request origin not allowed by Upgrader.CheckOrigin")
+
+func (u *Upgrader) Upgrade(
+	w http.ResponseWriter,
+	r *http.Request,
+	responseHeader http.Header,
+) (*websocket.Conn, http.ResponseWriter, *WebsocketSafeReadWriter, error) {
+	// we manually call CheckOrigin and pass a specific error to the client in this case
+	check := u.WSUpgrader.CheckOrigin(r)
+
+	if !check {
+		return nil, nil, nil, UpgraderCheckOriginErr
+	}
+
+	conn, err := u.WSUpgrader.Upgrade(w, r, responseHeader)
+
+	safeWriter := &WebsocketSafeReadWriter{conn}
+	rw := &WebsocketResponseWriter{conn, safeWriter}
+
+	return conn, rw, safeWriter, err
+}

+ 3 - 3
api/server/shared/writer.go

@@ -27,11 +27,11 @@ func NewDefaultResultWriter(conf *config.Config) ResultWriter {
 func (j *DefaultResultWriter) WriteResult(w http.ResponseWriter, r *http.Request, v interface{}) {
 	err := json.NewEncoder(w).Encode(v)
 
-	if errors.Is(err, syscall.EPIPE) {
-		// broken pipe error, ignore. This means the client closed the connection while
+	if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
+		// either a broken pipe error or econnreset, ignore. This means the client closed the connection while
 		// the server was sending bytes.
 		return
 	} else if err != nil {
-		apierrors.HandleAPIError(j.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(j.config, w, r, apierrors.NewErrInternal(err), true)
 	}
 }

+ 5 - 0
api/types/request.go

@@ -60,6 +60,9 @@ type APIRequestMetadata struct {
 
 	// Whether the endpoint should log
 	Quiet bool
+
+	// Whether the endpoint upgrades to a websocket
+	IsWebsocket bool
 }
 
 const RequestScopeCtxKey = "requestscopes"
@@ -68,3 +71,5 @@ type RequestAction struct {
 	Verb     APIVerb
 	Resource NameOrUInt
 }
+
+var RequestCtxWebsocketKey = "websocket"

+ 117 - 39
internal/kubernetes/agent.go

@@ -13,7 +13,10 @@ import (
 	"strings"
 	"time"
 
+	goerrors "errors"
+
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner/aws"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner/aws/ecr"
@@ -32,7 +35,6 @@ import (
 
 	errors2 "errors"
 
-	"github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/internal/helm/grapher"
 	appsv1 "k8s.io/api/apps/v1"
 	batchv1 "k8s.io/api/batch/v1"
@@ -41,9 +43,11 @@ import (
 	v1beta1 "k8s.io/api/extensions/v1beta1"
 	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/fields"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/runtime/schema"
 	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/apimachinery/pkg/watch"
 	"k8s.io/cli-runtime/pkg/genericclioptions"
 	"k8s.io/client-go/informers"
 	"k8s.io/client-go/kubernetes"
@@ -367,6 +371,14 @@ func (a *Agent) GetIngress(namespace string, name string) (*v1beta1.Ingress, err
 
 var IsNotFoundError = fmt.Errorf("not found")
 
+type BadRequestError struct {
+	msg string
+}
+
+func (e *BadRequestError) Error() string {
+	return e.msg
+}
+
 // GetDeployment gets the deployment given the name and namespace
 func (a *Agent) GetDeployment(c grapher.Object) (*appsv1.Deployment, error) {
 	res, err := a.Clientset.AppsV1().Deployments(c.Namespace).Get(
@@ -508,7 +520,7 @@ func (a *Agent) DeletePod(namespace string, name string) error {
 }
 
 // GetPodLogs streams real-time logs from a given pod.
-func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn) error {
+func (a *Agent) GetPodLogs(namespace string, name string, rw *websocket.WebsocketSafeReadWriter) error {
 	// get the pod to read in the list of contains
 	pod, err := a.Clientset.CoreV1().Pods(namespace).Get(
 		context.Background(),
@@ -522,6 +534,19 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 		return fmt.Errorf("Cannot get logs from pod %s: %s", name, err.Error())
 	}
 
+	// see if container is ready and able to open a stream. If not, wait for container
+	// to be ready.
+	err, isExited := a.waitForPod(pod)
+
+	if err != nil && goerrors.Is(err, IsNotFoundError) {
+		return IsNotFoundError
+	} else if err != nil {
+		return fmt.Errorf("Cannot get logs from pod %s: %s", name, err.Error())
+	} else if isExited {
+		// if exited, we return nil and simply close the stream
+		return nil
+	}
+
 	container := pod.Spec.Containers[0].Name
 
 	tails := int64(400)
@@ -537,9 +562,14 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 
 	podLogs, err := req.Stream(context.TODO())
 
-	if err != nil {
+	// in the case of bad request errors, such as if the pod is stuck in "ContainerCreating",
+	// we'd like to pass this through to the client.
+	if err != nil && errors.IsBadRequest(err) {
+		return &BadRequestError{err.Error()}
+	} else if err != nil {
 		return fmt.Errorf("Cannot open log stream for pod %s: %s", name, err.Error())
 	}
+
 	defer podLogs.Close()
 
 	r := bufio.NewReader(podLogs)
@@ -548,8 +578,7 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 	go func() {
 		// listens for websocket closing handshake
 		for {
-			if _, _, err := conn.ReadMessage(); err != nil {
-				defer conn.Close()
+			if _, _, err := rw.ReadMessage(); err != nil {
 				errorchan <- nil
 				return
 			}
@@ -566,7 +595,7 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 			}
 
 			bytes, err := r.ReadBytes('\n')
-			if writeErr := conn.WriteMessage(websocket.TextMessage, bytes); writeErr != nil {
+			if _, writeErr := rw.Write(bytes); writeErr != nil {
 				errorchan <- writeErr
 				return
 			}
@@ -669,7 +698,7 @@ func (a *Agent) RunWebsocketTask(task func() error) error {
 
 // StreamControllerStatus streams controller status. Supports Deployment, StatefulSet, ReplicaSet, and DaemonSet
 // TODO: Support Jobs
-func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, selectors string) error {
+func (a *Agent) StreamControllerStatus(kind string, selectors string, rw *websocket.WebsocketSafeReadWriter) error {
 
 	run := func() error {
 		// selectors is an array of max length 1. StreamControllerStatus accepts calls without the selectors argument.
@@ -723,10 +752,7 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    newObj,
 					Kind:      strings.ToLower(kind),
 				}
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			AddFunc: func(obj interface{}) {
 				msg := Message{
@@ -734,11 +760,7 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    obj,
 					Kind:      strings.ToLower(kind),
 				}
-
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			DeleteFunc: func(obj interface{}) {
 				msg := Message{
@@ -746,19 +768,14 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    obj,
 					Kind:      strings.ToLower(kind),
 				}
-
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 		})
 
 		go func() {
 			// listens for websocket closing handshake
 			for {
-				if _, _, err := conn.ReadMessage(); err != nil {
-					conn.Close()
+				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					return
 				}
@@ -847,8 +864,7 @@ func parseSecretToHelmRelease(secret v1.Secret, chartList []string) (*rspb.Relea
 	return helm_object, false, nil
 }
 
-func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chartList []string, selectors string) error {
-
+func (a *Agent) StreamHelmReleases(namespace string, chartList []string, selectors string, rw *websocket.WebsocketSafeReadWriter) error {
 	run := func() error {
 		tweakListOptionsFunc := func(options *metav1.ListOptions) {
 			options.LabelSelector = selectors
@@ -898,10 +914,7 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			AddFunc: func(obj interface{}) {
 				secretObj, ok := obj.(*v1.Secret)
@@ -927,10 +940,7 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			DeleteFunc: func(obj interface{}) {
 				secretObj, ok := obj.(*v1.Secret)
@@ -956,18 +966,14 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 		})
 
 		go func() {
 			// listens for websocket closing handshake
 			for {
-				if _, _, err := conn.ReadMessage(); err != nil {
-					conn.Close()
+				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					return
 				}
@@ -1343,3 +1349,75 @@ func (a *Agent) CreateImagePullSecrets(
 
 	return res, nil
 }
+
+// helper that waits for pod to be ready
+func (a *Agent) waitForPod(pod *v1.Pod) (error, bool) {
+	var (
+		w   watch.Interface
+		err error
+		ok  bool
+	)
+	// immediately after creating a pod, the API may return a 404. heuristically 1
+	// second seems to be plenty.
+	watchRetries := 3
+	for i := 0; i < watchRetries; i++ {
+		selector := fields.OneTermEqualSelector("metadata.name", pod.Name).String()
+		w, err = a.Clientset.CoreV1().
+			Pods(pod.Namespace).
+			Watch(context.Background(), metav1.ListOptions{FieldSelector: selector})
+
+		if err == nil {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+	if err != nil {
+		return err, false
+	}
+	defer w.Stop()
+	for {
+		select {
+		case <-time.After(time.Second * 30):
+			return goerrors.New("timed out waiting for pod"), false
+		case <-time.Tick(time.Second):
+			// poll every second in case we already missed the ready event while
+			// creating the listener.
+			pod, err = a.Clientset.CoreV1().
+				Pods(pod.Namespace).
+				Get(context.Background(), pod.Name, metav1.GetOptions{})
+
+			if err != nil && errors.IsNotFound(err) {
+				return IsNotFoundError, false
+			} else if err != nil {
+				return err, false
+			}
+
+			if isExited := isPodExited(pod); isExited || isPodReady(pod) {
+				return nil, isExited
+			}
+		case evt := <-w.ResultChan():
+			pod, ok = evt.Object.(*v1.Pod)
+			if !ok {
+				return fmt.Errorf("unexpected object type: %T", evt.Object), false
+			}
+			if isExited := isPodExited(pod); isExited || isPodReady(pod) {
+				return nil, isExited
+			}
+		}
+	}
+}
+
+func isPodReady(pod *v1.Pod) bool {
+	ready := false
+	conditions := pod.Status.Conditions
+	for i := range conditions {
+		if conditions[i].Type == v1.PodReady {
+			ready = pod.Status.Conditions[i].Status == v1.ConditionTrue
+		}
+	}
+	return ready
+}
+
+func isPodExited(pod *v1.Pod) bool {
+	return pod.Status.Phase == v1.PodSucceeded || pod.Status.Phase == v1.PodFailed
+}

+ 4 - 10
internal/kubernetes/provisioner/resource_stream.go

@@ -4,20 +4,17 @@ import (
 	"context"
 
 	redis "github.com/go-redis/redis/v8"
-	"github.com/gorilla/websocket"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 )
 
 // ResourceStream performs an XREAD operation on the given stream and outputs it to the given websocket conn.
-func ResourceStream(client *redis.Client, streamName string, conn *websocket.Conn) error {
+func ResourceStream(client *redis.Client, streamName string, rw *websocket.WebsocketSafeReadWriter) error {
 	errorchan := make(chan error)
 
 	go func() {
 		// listens for websocket closing handshake
 		for {
-			_, _, err := conn.ReadMessage()
-
-			if err != nil {
-				defer conn.Close()
+			if _, _, err := rw.ReadMessage(); err != nil {
 				errorchan <- nil
 				return
 			}
@@ -43,10 +40,7 @@ func ResourceStream(client *redis.Client, streamName string, conn *websocket.Con
 			messages := xstream[0].Messages
 			lastID = messages[len(messages)-1].ID
 
-			if writeErr := conn.WriteJSON(messages); writeErr != nil {
-				errorchan <- writeErr
-				return
-			}
+			rw.WriteJSONWithChannel(messages, errorchan)
 		}
 	}()