ソースを参照

consolidate websocket logic

Alexander Belanger 4 年 前
コミット
73c77ea708

+ 1 - 1
api/server/authn/handler.go

@@ -153,7 +153,7 @@ func (authn *AuthN) nextWithUserID(w http.ResponseWriter, r *http.Request, userI
 func (authn *AuthN) sendForbiddenError(err error, w http.ResponseWriter, r *http.Request) {
 	reqErr := apierrors.NewErrForbidden(err)
 
-	apierrors.HandleAPIError(authn.config, w, r, reqErr)
+	apierrors.HandleAPIError(authn.config, w, r, reqErr, true)
 }
 
 var errInvalidToken = fmt.Errorf("authorization header exists, but token is not valid")

+ 2 - 2
api/server/authz/cluster.go

@@ -52,9 +52,9 @@ func (p *ClusterScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("cluster with id %d not found in project %d", clusterID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/git_installation.go

@@ -46,12 +46,12 @@ func (p *GitInstallationScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *ht
 	gitInstallation, err := p.config.Repo.GithubAppInstallation().ReadGithubAppInstallationByInstallationID(gitInstallationID)
 
 	if err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 
 	if err := p.doesUserHaveGitInstallationAccess(user.GithubAppIntegrationID, gitInstallationID); err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 

+ 2 - 2
api/server/authz/helm_repo.go

@@ -45,9 +45,9 @@ func (p *HelmRepoScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("helm repo with id %d not found in project %d", helmRepoID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/infra.go

@@ -45,9 +45,9 @@ func (p *InfraScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("infra with id %d not found in project %d", infraID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 2
api/server/authz/invite.go

@@ -45,9 +45,9 @@ func (p *InviteScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("invite with id %d not found in project %d", inviteID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 3 - 2
api/server/authz/policy.go

@@ -43,7 +43,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	reqScopes, reqErr := getRequestActionForEndpoint(r, h.endpointMeta)
 
 	if reqErr != nil {
-		apierrors.HandleAPIError(h.config, w, r, reqErr)
+		apierrors.HandleAPIError(h.config, w, r, reqErr, true)
 		return
 	}
 
@@ -54,7 +54,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	policyDocs, reqErr := h.loader.LoadPolicyDocuments(user.ID, projID)
 
 	if reqErr != nil {
-		apierrors.HandleAPIError(h.config, w, r, reqErr)
+		apierrors.HandleAPIError(h.config, w, r, reqErr, true)
 		return
 	}
 
@@ -67,6 +67,7 @@ func (h *PolicyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			w,
 			r,
 			apierrors.NewErrForbidden(fmt.Errorf("policy forbids action for user %d in project %d", user.ID, projID)),
+			true,
 		)
 
 		return

+ 2 - 2
api/server/authz/project.go

@@ -43,12 +43,12 @@ func (p *ProjectScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("project not found with id %d", projID),
-			))
+			), true)
 
 			return
 		}
 
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 

+ 2 - 2
api/server/authz/registry.go

@@ -45,9 +45,9 @@ func (p *RegistryScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		if err == gorm.ErrRecordNotFound {
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrForbidden(
 				fmt.Errorf("registry with id %d not found in project %d", registryID, proj.ID),
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 3 - 3
api/server/authz/release.go

@@ -40,7 +40,7 @@ func (p *ReleaseScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 	helmAgent, err := p.agentGetter.GetHelmAgent(r, cluster, "")
 
 	if err != nil {
-		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		return
 	}
 
@@ -60,9 +60,9 @@ func (p *ReleaseScopedMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reque
 			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrPassThroughToClient(
 				fmt.Errorf("release not found"),
 				http.StatusNotFound,
-			))
+			), true)
 		} else {
-			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err))
+			apierrors.HandleAPIError(p.config, w, r, apierrors.NewErrInternal(err), true)
 		}
 
 		return

+ 2 - 10
api/server/handlers/cluster/stream_helm_release.go

@@ -8,6 +8,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 )
@@ -29,16 +30,7 @@ func NewStreamHelmReleaseHandler(
 }
 
 func (c *StreamHelmReleaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	w = newRW
-	defer conn.Close()
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	request := &types.StreamHelmReleaseRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {

+ 2 - 10
api/server/handlers/cluster/stream_status.go

@@ -9,6 +9,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/requestutils"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/models"
 )
@@ -30,16 +31,7 @@ func NewStreamStatusHandler(
 }
 
 func (c *StreamStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	w = newRW
-	defer conn.Close()
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	request := &types.StreamStatusRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {

+ 6 - 1
api/server/handlers/handler.go

@@ -16,6 +16,7 @@ type PorterHandler interface {
 	Config() *config.Config
 	Repo() repository.Repository
 	HandleAPIError(w http.ResponseWriter, r *http.Request, err apierrors.RequestError)
+	HandleAPIErrorNoWrite(w http.ResponseWriter, r *http.Request, err apierrors.RequestError)
 	PopulateOAuthSession(w http.ResponseWriter, r *http.Request, state string, isProject bool) error
 }
 
@@ -57,7 +58,11 @@ func (d *DefaultPorterHandler) Repo() repository.Repository {
 }
 
 func (d *DefaultPorterHandler) HandleAPIError(w http.ResponseWriter, r *http.Request, err apierrors.RequestError) {
-	apierrors.HandleAPIError(d.Config(), w, r, err)
+	apierrors.HandleAPIError(d.Config(), w, r, err, true)
+}
+
+func (d *DefaultPorterHandler) HandleAPIErrorNoWrite(w http.ResponseWriter, r *http.Request, err apierrors.RequestError) {
+	apierrors.HandleAPIError(d.Config(), w, r, err, false)
 }
 
 func (d *DefaultPorterHandler) WriteResult(w http.ResponseWriter, r *http.Request, v interface{}) {

+ 2 - 10
api/server/handlers/infra/stream_logs.go

@@ -7,6 +7,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared"
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/adapter"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner"
@@ -27,16 +28,7 @@ func NewInfraStreamLogsHandler(
 }
 
 func (c *InfraStreamLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	w = newRW
-	defer conn.Close()
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	infra, _ := r.Context().Value(types.InfraScope).(*models.Infra)
 
 	client, err := adapter.NewRedisClient(c.Config().RedisConf)

+ 2 - 10
api/server/handlers/namespace/stream_pod_logs.go

@@ -11,6 +11,7 @@ import (
 	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/requestutils"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/api/types"
 	"github.com/porter-dev/porter/internal/kubernetes"
 	"github.com/porter-dev/porter/internal/models"
@@ -33,16 +34,7 @@ func NewStreamPodLogsHandler(
 }
 
 func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
-
-	if err != nil {
-		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
-		return
-	}
-
-	w = newRW
-	defer conn.Close()
-
+	safeRW := r.Context().Value(types.RequestCtxWebsocketKey).(*websocket.WebsocketSafeReadWriter)
 	namespace := r.Context().Value(types.NamespaceScope).(string)
 	name, _ := requestutils.GetURLParamString(r, types.URLParamPodName)
 

+ 2 - 0
api/server/router/cluster.go

@@ -518,6 +518,7 @@ func getClusterRoutes(
 				types.ProjectScope,
 				types.ClusterScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 
@@ -551,6 +552,7 @@ func getClusterRoutes(
 				types.ProjectScope,
 				types.ClusterScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 1 - 0
api/server/router/infra.go

@@ -121,6 +121,7 @@ func getInfraRoutes(
 				types.ProjectScope,
 				types.InfraScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 11 - 0
api/server/router/middleware/content_type_json.go

@@ -0,0 +1,11 @@
+package middleware
+
+import "net/http"
+
+// ContentTypeJSON sets the content type for requests to application/json
+func ContentTypeJSON(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json;charset=utf8")
+		next.ServeHTTP(w, r)
+	})
+}

+ 31 - 0
api/server/router/middleware/panic.go

@@ -0,0 +1,31 @@
+package middleware
+
+import (
+	"fmt"
+	"net/http"
+
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+)
+
+type PanicMiddleware struct {
+	config *config.Config
+}
+
+func NewPanicMiddleware(config *config.Config) *PanicMiddleware {
+	return &PanicMiddleware{config}
+}
+
+func (pmw *PanicMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		defer func() {
+			err := recover()
+
+			if err != nil {
+				apierrors.HandleAPIError(pmw.config, w, r, apierrors.NewErrInternal(fmt.Errorf("%v", err)), true)
+			}
+		}()
+
+		next.ServeHTTP(w, r)
+	})
+}

+ 59 - 0
api/server/router/middleware/request_logger.go

@@ -0,0 +1,59 @@
+package middleware
+
+import (
+	"bufio"
+	"errors"
+	"net"
+	"net/http"
+	"time"
+
+	"github.com/porter-dev/porter/internal/logger"
+)
+
+type requestLoggerResponseWriter struct {
+	http.ResponseWriter
+	statusCode int
+}
+
+func newRequestLoggerResponseWriter(w http.ResponseWriter) *requestLoggerResponseWriter {
+	return &requestLoggerResponseWriter{w, http.StatusOK}
+}
+
+func (rw *requestLoggerResponseWriter) WriteHeader(code int) {
+	rw.statusCode = code
+	rw.ResponseWriter.WriteHeader(code)
+}
+
+func (rw *requestLoggerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	h, ok := rw.ResponseWriter.(http.Hijacker)
+	if !ok {
+		return nil, nil, errors.New("ResponseWriter Interface does not support hijacking")
+	}
+	return h.Hijack()
+}
+
+type RequestLoggerMiddleware struct {
+	logger *logger.Logger
+}
+
+func NewRequestLoggerMiddleware(logger *logger.Logger) *RequestLoggerMiddleware {
+	return &RequestLoggerMiddleware{logger}
+}
+
+func (mw *RequestLoggerMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		start := time.Now()
+		rw := newRequestLoggerResponseWriter(w)
+
+		next.ServeHTTP(rw, r)
+
+		latency := time.Since(start)
+
+		event := mw.logger.Info().Dur("latency", latency).Int("status", rw.statusCode)
+
+		logger.AddLoggingContextScopes(r.Context(), event)
+		logger.AddLoggingRequestMeta(r, event)
+
+		event.Send()
+	})
+}

+ 45 - 0
api/server/router/middleware/websocket.go

@@ -0,0 +1,45 @@
+package middleware
+
+import (
+	"context"
+	"errors"
+	"net/http"
+
+	"github.com/porter-dev/porter/api/server/shared/apierrors"
+	"github.com/porter-dev/porter/api/server/shared/config"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
+	"github.com/porter-dev/porter/api/types"
+)
+
+type WebsocketMiddleware struct {
+	config *config.Config
+}
+
+func NewWebsocketMiddleware(config *config.Config) *WebsocketMiddleware {
+	return &WebsocketMiddleware{config}
+}
+
+func (wm *WebsocketMiddleware) Middleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, newRW, safeRW, err := wm.config.WSUpgrader.Upgrade(w, r, nil)
+
+		if err != nil {
+			if errors.Is(err, websocket.UpgraderCheckOriginErr) {
+				apierrors.HandleAPIError(wm.config, w, r, apierrors.NewErrForbidden(err), true)
+				return
+			} else {
+				apierrors.HandleAPIError(wm.config, w, r, apierrors.NewErrInternal(err), false)
+				return
+			}
+		}
+
+		w = newRW
+		defer conn.Close()
+
+		ctx := r.Context()
+		ctx = context.WithValue(ctx, types.RequestCtxWebsocketKey, safeRW)
+
+		r = r.Clone(ctx)
+		next.ServeHTTP(w, r)
+	})
+}

+ 1 - 0
api/server/router/namespace.go

@@ -283,6 +283,7 @@ func getNamespaceRoutes(
 				types.ClusterScope,
 				types.NamespaceScope,
 			},
+			IsWebsocket: true,
 		},
 	)
 

+ 11 - 80
api/server/router/router.go

@@ -1,25 +1,19 @@
 package router
 
 import (
-	"bufio"
-	"errors"
-	"fmt"
-	"net"
 	"net/http"
 	"os"
 	"path"
 	"strings"
-	"time"
 
 	"github.com/go-chi/chi"
 	"github.com/porter-dev/porter/api/server/authn"
 	"github.com/porter-dev/porter/api/server/authz"
 	"github.com/porter-dev/porter/api/server/authz/policy"
+	"github.com/porter-dev/porter/api/server/router/middleware"
 	"github.com/porter-dev/porter/api/server/shared"
-	"github.com/porter-dev/porter/api/server/shared/apierrors"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/types"
-	"github.com/porter-dev/porter/internal/logger"
 )
 
 func NewAPIRouter(config *config.Config) *chi.Mux {
@@ -54,14 +48,14 @@ func NewAPIRouter(config *config.Config) *chi.Mux {
 	)
 
 	userRegisterer := NewUserScopedRegisterer(projRegisterer)
-	panicMW := &PanicMiddleware{config}
+	panicMW := middleware.NewPanicMiddleware(config)
 
 	r.Route("/api", func(r chi.Router) {
 		// set panic middleware for all API endpoints to catch panics
 		r.Use(panicMW.Middleware)
 
 		// set the content type for all API endpoints and log all request info
-		r.Use(ContentTypeJSON)
+		r.Use(middleware.ContentTypeJSON)
 
 		baseRoutes := baseRegisterer.GetRoutes(
 			r,
@@ -190,7 +184,10 @@ func registerRoutes(config *config.Config, routes []*Route) {
 	policyDocLoader := policy.NewBasicPolicyDocumentLoader(config.Repo.Project())
 
 	// set up logging middleware to log information about the request
-	loggerMw := &RequestLoggerMiddleware{config.Logger}
+	loggerMw := middleware.NewRequestLoggerMiddleware(config.Logger)
+
+	// websocket middleware for upgrading requests
+	websocketMw := middleware.NewWebsocketMiddleware(config)
 
 	for _, route := range routes {
 		atomicGroup := route.Router.Group(nil)
@@ -232,6 +229,10 @@ func registerRoutes(config *config.Config, routes []*Route) {
 			atomicGroup.Use(loggerMw.Middleware)
 		}
 
+		if route.Endpoint.Metadata.IsWebsocket {
+			atomicGroup.Use(websocketMw.Middleware)
+		}
+
 		atomicGroup.Method(
 			string(route.Endpoint.Metadata.Method),
 			route.Endpoint.Metadata.Path.RelativePath,
@@ -239,73 +240,3 @@ func registerRoutes(config *config.Config, routes []*Route) {
 		)
 	}
 }
-
-// ContentTypeJSON sets the content type for requests to application/json
-func ContentTypeJSON(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "application/json;charset=utf8")
-		next.ServeHTTP(w, r)
-	})
-}
-
-type requestLoggerResponseWriter struct {
-	http.ResponseWriter
-	statusCode int
-}
-
-func newRequestLoggerResponseWriter(w http.ResponseWriter) *requestLoggerResponseWriter {
-	return &requestLoggerResponseWriter{w, http.StatusOK}
-}
-
-func (rw *requestLoggerResponseWriter) WriteHeader(code int) {
-	rw.statusCode = code
-	rw.ResponseWriter.WriteHeader(code)
-}
-
-func (rw *requestLoggerResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	h, ok := rw.ResponseWriter.(http.Hijacker)
-	if !ok {
-		return nil, nil, errors.New("ResponseWriter Interface does not support hijacking")
-	}
-	return h.Hijack()
-}
-
-type RequestLoggerMiddleware struct {
-	logger *logger.Logger
-}
-
-func (mw *RequestLoggerMiddleware) Middleware(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		start := time.Now()
-		rw := newRequestLoggerResponseWriter(w)
-
-		next.ServeHTTP(rw, r)
-
-		latency := time.Since(start)
-
-		event := mw.logger.Info().Dur("latency", latency).Int("status", rw.statusCode)
-
-		logger.AddLoggingContextScopes(r.Context(), event)
-		logger.AddLoggingRequestMeta(r, event)
-
-		event.Send()
-	})
-}
-
-type PanicMiddleware struct {
-	config *config.Config
-}
-
-func (pmw *PanicMiddleware) Middleware(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		defer func() {
-			err := recover()
-
-			if err != nil {
-				apierrors.HandleAPIError(pmw.config, w, r, apierrors.NewErrInternal(fmt.Errorf("%v", err)))
-			}
-		}()
-
-		next.ServeHTTP(w, r)
-	})
-}

+ 16 - 13
api/server/shared/apierrors/errors.go

@@ -95,6 +95,7 @@ func HandleAPIError(
 	w http.ResponseWriter,
 	r *http.Request,
 	err RequestError,
+	writeErr bool,
 ) {
 	extErrorStr := err.ExternalError()
 
@@ -116,24 +117,26 @@ func HandleAPIError(
 		config.Alerter.SendAlert(r.Context(), err, data)
 	}
 
-	// send the external error
-	resp := &types.ExternalError{
-		Error: extErrorStr,
-	}
+	if writeErr {
+		// send the external error
+		resp := &types.ExternalError{
+			Error: extErrorStr,
+		}
 
-	// write the status code
-	w.WriteHeader(err.GetStatusCode())
+		// write the status code
+		w.WriteHeader(err.GetStatusCode())
 
-	writerErr := json.NewEncoder(w).Encode(resp)
+		writerErr := json.NewEncoder(w).Encode(resp)
 
-	if writerErr != nil {
-		event := config.Logger.Error().
-			Err(writerErr)
+		if writerErr != nil {
+			event := config.Logger.Error().
+				Err(writerErr)
 
-		logger.AddLoggingContextScopes(r.Context(), event)
-		logger.AddLoggingRequestMeta(r, event)
+			logger.AddLoggingContextScopes(r.Context(), event)
+			logger.AddLoggingRequestMeta(r, event)
 
-		event.Send()
+			event.Send()
+		}
 	}
 
 	return

+ 1 - 1
api/server/shared/apitest/request.go

@@ -65,7 +65,7 @@ func (f *failingDecoderValidator) DecodeAndValidate(
 	r *http.Request,
 	v interface{},
 ) (ok bool) {
-	apierrors.HandleAPIError(f.config, w, r, apierrors.NewErrInternal(fmt.Errorf("fake error")))
+	apierrors.HandleAPIError(f.config, w, r, apierrors.NewErrInternal(fmt.Errorf("fake error")), true)
 	return false
 }
 

+ 2 - 2
api/server/shared/reader.go

@@ -38,13 +38,13 @@ func (j *DefaultRequestDecoderValidator) DecodeAndValidate(
 
 	// decode the request parameters (body and query)
 	if requestErr = j.decoder.Decode(v, r); requestErr != nil {
-		apierrors.HandleAPIError(j.config, w, r, requestErr)
+		apierrors.HandleAPIError(j.config, w, r, requestErr, true)
 		return false
 	}
 
 	// validate the request object
 	if requestErr = j.validator.Validate(v); requestErr != nil {
-		apierrors.HandleAPIError(j.config, w, r, requestErr)
+		apierrors.HandleAPIError(j.config, w, r, requestErr, true)
 		return false
 	}
 

+ 10 - 0
api/server/shared/websocket/upgrader.go

@@ -1,6 +1,7 @@
 package websocket
 
 import (
+	"fmt"
 	"net/http"
 
 	"github.com/gorilla/websocket"
@@ -10,11 +11,20 @@ type Upgrader struct {
 	WSUpgrader *websocket.Upgrader
 }
 
+var UpgraderCheckOriginErr = fmt.Errorf("request origin not allowed by Upgrader.CheckOrigin")
+
 func (u *Upgrader) Upgrade(
 	w http.ResponseWriter,
 	r *http.Request,
 	responseHeader http.Header,
 ) (*websocket.Conn, http.ResponseWriter, *WebsocketSafeReadWriter, error) {
+	// we manually call CheckOrigin and pass a specific error to the client in this case
+	check := u.WSUpgrader.CheckOrigin(r)
+
+	if !check {
+		return nil, nil, nil, UpgraderCheckOriginErr
+	}
+
 	conn, err := u.WSUpgrader.Upgrade(w, r, responseHeader)
 
 	safeWriter := &WebsocketSafeReadWriter{conn}

+ 1 - 1
api/server/shared/writer.go

@@ -32,6 +32,6 @@ func (j *DefaultResultWriter) WriteResult(w http.ResponseWriter, r *http.Request
 		// the server was sending bytes.
 		return
 	} else if err != nil {
-		apierrors.HandleAPIError(j.config, w, r, apierrors.NewErrInternal(err))
+		apierrors.HandleAPIError(j.config, w, r, apierrors.NewErrInternal(err), true)
 	}
 }

+ 5 - 0
api/types/request.go

@@ -60,6 +60,9 @@ type APIRequestMetadata struct {
 
 	// Whether the endpoint should log
 	Quiet bool
+
+	// Whether the endpoint upgrades to a websocket
+	IsWebsocket bool
 }
 
 const RequestScopeCtxKey = "requestscopes"
@@ -68,3 +71,5 @@ type RequestAction struct {
 	Verb     APIVerb
 	Resource NameOrUInt
 }
+
+var RequestCtxWebsocketKey = "websocket"