websocket.go 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. package middleware
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "github.com/porter-dev/porter/api/server/shared/apierrors"
  7. "github.com/porter-dev/porter/api/server/shared/config"
  8. "github.com/porter-dev/porter/api/server/shared/websocket"
  9. "github.com/porter-dev/porter/api/types"
  10. )
  11. type WebsocketMiddleware struct {
  12. config *config.Config
  13. }
  14. func NewWebsocketMiddleware(config *config.Config) *WebsocketMiddleware {
  15. return &WebsocketMiddleware{config}
  16. }
  17. func (wm *WebsocketMiddleware) Middleware(next http.Handler) http.Handler {
  18. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. conn, newRW, safeRW, err := wm.config.WSUpgrader.Upgrade(w, r, nil)
  20. if err != nil {
  21. if errors.Is(err, websocket.UpgraderCheckOriginErr) {
  22. apierrors.HandleAPIError(wm.config.Logger, wm.config.Alerter, w, r, apierrors.NewErrForbidden(err), true)
  23. return
  24. } else {
  25. apierrors.HandleAPIError(wm.config.Logger, wm.config.Alerter, w, r, apierrors.NewErrInternal(err), false)
  26. return
  27. }
  28. }
  29. w = newRW
  30. defer conn.Close()
  31. ctx := r.Context()
  32. ctx = context.WithValue(ctx, types.RequestCtxWebsocketKey, safeRW)
  33. r = r.Clone(ctx)
  34. next.ServeHTTP(w, r)
  35. })
  36. }