Quellcode durchsuchen

refactor websocket handlers for stability

Alexander Belanger vor 4 Jahren
Ursprung
Commit
4406e76bfc

+ 5 - 2
api/server/handlers/cluster/stream_helm_release.go

@@ -29,13 +29,16 @@ func NewStreamHelmReleaseHandler(
 }
 
 func (c *StreamHelmReleaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
+	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		return
 	}
 
+	w = newRW
+	defer conn.Close()
+
 	request := &types.StreamHelmReleaseRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {
@@ -51,7 +54,7 @@ func (c *StreamHelmReleaseHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
 		return
 	}
 
-	err = agent.StreamHelmReleases(conn, request.Namespace, request.Charts, request.Selectors)
+	err = agent.StreamHelmReleases(request.Namespace, request.Charts, request.Selectors, safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 5 - 2
api/server/handlers/cluster/stream_status.go

@@ -30,13 +30,16 @@ func NewStreamStatusHandler(
 }
 
 func (c *StreamStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
+	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		return
 	}
 
+	w = newRW
+	defer conn.Close()
+
 	request := &types.StreamStatusRequest{}
 
 	if ok := c.DecodeAndValidate(w, r, request); !ok {
@@ -54,7 +57,7 @@ func (c *StreamStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 
 	kind, _ := requestutils.GetURLParamString(r, types.URLParamKind)
 
-	err = agent.StreamControllerStatus(conn, kind, request.Selectors)
+	err = agent.StreamControllerStatus(kind, request.Selectors, safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 5 - 2
api/server/handlers/infra/stream_logs.go

@@ -27,13 +27,16 @@ func NewInfraStreamLogsHandler(
 }
 
 func (c *InfraStreamLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
+	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		return
 	}
 
+	w = newRW
+	defer conn.Close()
+
 	infra, _ := r.Context().Value(types.InfraScope).(*models.Infra)
 
 	client, err := adapter.NewRedisClient(c.Config().RedisConf)
@@ -43,7 +46,7 @@ func (c *InfraStreamLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
 		return
 	}
 
-	err = provisioner.ResourceStream(client, infra.GetUniqueName(), conn)
+	err = provisioner.ResourceStream(client, infra.GetUniqueName(), safeRW)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 12 - 2
api/server/handlers/namespace/stream_pod_logs.go

@@ -33,13 +33,16 @@ func NewStreamPodLogsHandler(
 }
 
 func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	conn, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
+	conn, newRW, safeRW, err := c.Config().WSUpgrader.Upgrade(w, r, nil)
 
 	if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
 		return
 	}
 
+	w = newRW
+	defer conn.Close()
+
 	namespace := r.Context().Value(types.NamespaceScope).(string)
 	name, _ := requestutils.GetURLParamString(r, types.URLParamPodName)
 
@@ -52,7 +55,7 @@ func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 		return
 	}
 
-	err = agent.GetPodLogs(namespace, name, conn)
+	err = agent.GetPodLogs(namespace, name, safeRW)
 
 	if targetErr := kubernetes.IsNotFoundError; errors.Is(err, targetErr) {
 		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(
@@ -60,6 +63,13 @@ func (c *StreamPodLogsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
 			http.StatusNotFound,
 		))
 
+		return
+	} else if brErr := (kubernetes.BadRequestError{}); errors.As(err, &targetErr) {
+		c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(
+			&brErr,
+			http.StatusBadRequest,
+		))
+
 		return
 	} else if err != nil {
 		c.HandleAPIError(w, r, apierrors.NewErrInternal(err))

+ 1 - 1
api/server/shared/config/config.go

@@ -2,9 +2,9 @@ package config
 
 import (
 	"github.com/gorilla/sessions"
-	"github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/api/server/shared/apierrors/alerter"
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/analytics"
 	"github.com/porter-dev/porter/internal/auth/token"
 	"github.com/porter-dev/porter/internal/helm/urlcache"

+ 9 - 6
api/server/shared/config/loader/loader.go

@@ -5,10 +5,11 @@ import (
 	"net/http"
 	"strconv"
 
-	"github.com/gorilla/websocket"
+	gorillaws "github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/api/server/shared/apierrors/alerter"
 	"github.com/porter-dev/porter/api/server/shared/config"
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/adapter"
 	"github.com/porter-dev/porter/internal/analytics"
 	"github.com/porter-dev/porter/internal/auth/sessionstore"
@@ -165,11 +166,13 @@ func (e *EnvConfigLoader) LoadConfig() (res *config.Config, err error) {
 	}
 
 	res.WSUpgrader = &websocket.Upgrader{
-		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			origin := r.Header.Get("Origin")
-			return origin == sc.ServerURL
+		WSUpgrader: &gorillaws.Upgrader{
+			ReadBufferSize:  1024,
+			WriteBufferSize: 1024,
+			CheckOrigin: func(r *http.Request) bool {
+				origin := r.Header.Get("Origin")
+				return origin == sc.ServerURL
+			},
 		},
 	}
 

+ 67 - 0
api/server/shared/websocket/response_writer.go

@@ -0,0 +1,67 @@
+package websocket
+
+import (
+	"errors"
+	"net/http"
+	"syscall"
+
+	"github.com/gorilla/websocket"
+)
+
+// WebsocketSafeReadWriter wraps a websocket connection and exposes write
+// helpers that treat a client disconnect (broken pipe or connection reset) as
+// an ignorable condition rather than as a failure.
+type WebsocketSafeReadWriter struct {
+	conn *websocket.Conn
+}
+
+// isClientDisconnect reports whether err is, or wraps, a broken pipe (EPIPE)
+// or connection reset (ECONNRESET) error.
+func isClientDisconnect(err error) bool {
+	return errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET)
+}
+
+// WriteJSONWithChannel writes v to the websocket connection as JSON. Nothing
+// is sent on errorChan when the write succeeds. When the write fails, exactly
+// one value is sent on errorChan: nil for broken pipe and connection reset
+// errors, and the write error itself for anything else.
+func (w *WebsocketSafeReadWriter) WriteJSONWithChannel(v interface{}, errorChan chan<- error) {
+	err := w.conn.WriteJSON(v)
+	if err == nil {
+		return
+	}
+
+	// we ignore broken pipe errors and connection reset errors, but we want to
+	// send a message to the error channel to ensure closure
+	if isClientDisconnect(err) {
+		errorChan <- nil
+		return
+	}
+
+	errorChan <- err
+}
+
+// Write sends data to the websocket connection as a single text message. It
+// returns len(data) on success. Broken pipe and connection reset errors are
+// swallowed: zero bytes are reported as written and the returned error is nil.
+// Any other write error is returned to the caller.
+func (w *WebsocketSafeReadWriter) Write(data []byte) (int, error) {
+	err := w.conn.WriteMessage(websocket.TextMessage, data)
+	if err == nil {
+		return len(data), nil
+	}
+
+	// we ignore broken pipe errors and connection reset errors; every other
+	// error is surfaced
+	if isClientDisconnect(err) {
+		return 0, nil
+	}
+
+	return 0, err
+}
+
+// ReadMessage reads the next message from the underlying websocket connection
+// and returns its result unchanged.
+func (w *WebsocketSafeReadWriter) ReadMessage() (messageType int, p []byte, err error) {
+	return w.conn.ReadMessage()
+}
+
+// WebsocketResponseWriter provides the http.ResponseWriter method set on top
+// of a websocket connection, so that code written against http.ResponseWriter
+// can emit its response body as a websocket message.
+type WebsocketResponseWriter struct {
+	conn       *websocket.Conn
+	safeWriter *WebsocketSafeReadWriter
+}
+
+// Header returns a fresh, empty header map on every call. There are no HTTP
+// headers in the websocket protocol, so anything set on the returned map is
+// discarded. The map is non-nil because callers that do w.Header().Set(...)
+// would panic when assigning into a nil map.
+func (w *WebsocketResponseWriter) Header() http.Header {
+	return http.Header{}
+}
+
+// Write attempts to write a message to the websocket connection
+func (w *WebsocketResponseWriter) Write(data []byte) (int, error) {
+	return w.safeWriter.Write(data)
+}
+
+// WriteHeader is a no-op; there is no HTTP status line or header block in the
+// websocket protocol.
+func (w *WebsocketResponseWriter) WriteHeader(statusCode int) {
+}

+ 24 - 0
api/server/shared/websocket/upgrader.go

@@ -0,0 +1,24 @@
+package websocket
+
+import (
+	"net/http"
+
+	"github.com/gorilla/websocket"
+)
+
+// Upgrader wraps a gorilla websocket upgrader and, on a successful upgrade,
+// additionally hands back writers bound to the new connection.
+type Upgrader struct {
+	WSUpgrader *websocket.Upgrader
+}
+
+// Upgrade upgrades the HTTP connection to the websocket protocol using the
+// wrapped upgrader.
+//
+// On success it returns the raw connection, an http.ResponseWriter whose
+// writes go to the websocket, and the read/writer backing that response
+// writer. On failure it returns only the error; the other results are nil so
+// that no wrapper around a missing connection is ever handed out.
+func (u *Upgrader) Upgrade(
+	w http.ResponseWriter,
+	r *http.Request,
+	responseHeader http.Header,
+) (*websocket.Conn, http.ResponseWriter, *WebsocketSafeReadWriter, error) {
+	conn, err := u.WSUpgrader.Upgrade(w, r, responseHeader)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	safeWriter := &WebsocketSafeReadWriter{conn}
+	rw := &WebsocketResponseWriter{conn, safeWriter}
+
+	return conn, rw, safeWriter, nil
+}

+ 117 - 39
internal/kubernetes/agent.go

@@ -13,7 +13,10 @@ import (
 	"strings"
 	"time"
 
+	goerrors "errors"
+
 	"github.com/porter-dev/porter/api/server/shared/config/env"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner/aws"
 	"github.com/porter-dev/porter/internal/kubernetes/provisioner/aws/ecr"
@@ -32,7 +35,6 @@ import (
 
 	errors2 "errors"
 
-	"github.com/gorilla/websocket"
 	"github.com/porter-dev/porter/internal/helm/grapher"
 	appsv1 "k8s.io/api/apps/v1"
 	batchv1 "k8s.io/api/batch/v1"
@@ -41,9 +43,11 @@ import (
 	v1beta1 "k8s.io/api/extensions/v1beta1"
 	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/fields"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/runtime/schema"
 	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/apimachinery/pkg/watch"
 	"k8s.io/cli-runtime/pkg/genericclioptions"
 	"k8s.io/client-go/informers"
 	"k8s.io/client-go/kubernetes"
@@ -367,6 +371,14 @@ func (a *Agent) GetIngress(namespace string, name string) (*v1beta1.Ingress, err
 
 var IsNotFoundError = fmt.Errorf("not found")
 
+// BadRequestError carries the message of a request that was rejected as a
+// bad request. Only the pointer type implements the error interface.
+type BadRequestError struct {
+	msg string
+}
+
+// Error returns the stored message.
+func (b *BadRequestError) Error() string {
+	message := b.msg
+	return message
+}
+
 // GetDeployment gets the deployment given the name and namespace
 func (a *Agent) GetDeployment(c grapher.Object) (*appsv1.Deployment, error) {
 	res, err := a.Clientset.AppsV1().Deployments(c.Namespace).Get(
@@ -508,7 +520,7 @@ func (a *Agent) DeletePod(namespace string, name string) error {
 }
 
 // GetPodLogs streams real-time logs from a given pod.
-func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn) error {
+func (a *Agent) GetPodLogs(namespace string, name string, rw *websocket.WebsocketSafeReadWriter) error {
 	// get the pod to read in the list of contains
 	pod, err := a.Clientset.CoreV1().Pods(namespace).Get(
 		context.Background(),
@@ -522,6 +534,19 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 		return fmt.Errorf("Cannot get logs from pod %s: %s", name, err.Error())
 	}
 
+	// see if container is ready and able to open a stream. If not, wait for container
+	// to be ready.
+	err, isExited := a.waitForPod(pod)
+
+	if err != nil && goerrors.Is(err, IsNotFoundError) {
+		return IsNotFoundError
+	} else if err != nil {
+		return fmt.Errorf("Cannot get logs from pod %s: %s", name, err.Error())
+	} else if isExited {
+		// if exited, we return nil and simply close the stream
+		return nil
+	}
+
 	container := pod.Spec.Containers[0].Name
 
 	tails := int64(400)
@@ -537,9 +562,14 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 
 	podLogs, err := req.Stream(context.TODO())
 
-	if err != nil {
+	// in the case of bad request errors, such as if the pod is stuck in "ContainerCreating",
+	// we'd like to pass this through to the client.
+	if err != nil && errors.IsBadRequest(err) {
+		return &BadRequestError{err.Error()}
+	} else if err != nil {
 		return fmt.Errorf("Cannot open log stream for pod %s: %s", name, err.Error())
 	}
+
 	defer podLogs.Close()
 
 	r := bufio.NewReader(podLogs)
@@ -548,8 +578,7 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 	go func() {
 		// listens for websocket closing handshake
 		for {
-			if _, _, err := conn.ReadMessage(); err != nil {
-				defer conn.Close()
+			if _, _, err := rw.ReadMessage(); err != nil {
 				errorchan <- nil
 				return
 			}
@@ -566,7 +595,7 @@ func (a *Agent) GetPodLogs(namespace string, name string, conn *websocket.Conn)
 			}
 
 			bytes, err := r.ReadBytes('\n')
-			if writeErr := conn.WriteMessage(websocket.TextMessage, bytes); writeErr != nil {
+			if _, writeErr := rw.Write(bytes); writeErr != nil {
 				errorchan <- writeErr
 				return
 			}
@@ -669,7 +698,7 @@ func (a *Agent) RunWebsocketTask(task func() error) error {
 
 // StreamControllerStatus streams controller status. Supports Deployment, StatefulSet, ReplicaSet, and DaemonSet
 // TODO: Support Jobs
-func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, selectors string) error {
+func (a *Agent) StreamControllerStatus(kind string, selectors string, rw *websocket.WebsocketSafeReadWriter) error {
 
 	run := func() error {
 		// selectors is an array of max length 1. StreamControllerStatus accepts calls without the selectors argument.
@@ -723,10 +752,7 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    newObj,
 					Kind:      strings.ToLower(kind),
 				}
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			AddFunc: func(obj interface{}) {
 				msg := Message{
@@ -734,11 +760,7 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    obj,
 					Kind:      strings.ToLower(kind),
 				}
-
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			DeleteFunc: func(obj interface{}) {
 				msg := Message{
@@ -746,19 +768,14 @@ func (a *Agent) StreamControllerStatus(conn *websocket.Conn, kind string, select
 					Object:    obj,
 					Kind:      strings.ToLower(kind),
 				}
-
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 		})
 
 		go func() {
 			// listens for websocket closing handshake
 			for {
-				if _, _, err := conn.ReadMessage(); err != nil {
-					conn.Close()
+				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					return
 				}
@@ -847,8 +864,7 @@ func parseSecretToHelmRelease(secret v1.Secret, chartList []string) (*rspb.Relea
 	return helm_object, false, nil
 }
 
-func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chartList []string, selectors string) error {
-
+func (a *Agent) StreamHelmReleases(namespace string, chartList []string, selectors string, rw *websocket.WebsocketSafeReadWriter) error {
 	run := func() error {
 		tweakListOptionsFunc := func(options *metav1.ListOptions) {
 			options.LabelSelector = selectors
@@ -898,10 +914,7 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			AddFunc: func(obj interface{}) {
 				secretObj, ok := obj.(*v1.Secret)
@@ -927,10 +940,7 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 			DeleteFunc: func(obj interface{}) {
 				secretObj, ok := obj.(*v1.Secret)
@@ -956,18 +966,14 @@ func (a *Agent) StreamHelmReleases(conn *websocket.Conn, namespace string, chart
 					Object:    helm_object,
 				}
 
-				if writeErr := conn.WriteJSON(msg); writeErr != nil {
-					errorchan <- writeErr
-					return
-				}
+				rw.WriteJSONWithChannel(msg, errorchan)
 			},
 		})
 
 		go func() {
 			// listens for websocket closing handshake
 			for {
-				if _, _, err := conn.ReadMessage(); err != nil {
-					conn.Close()
+				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					return
 				}
@@ -1343,3 +1349,75 @@ func (a *Agent) CreateImagePullSecrets(
 
 	return res, nil
 }
+
+// waitForPod blocks until the given pod is ready or has exited, an API call
+// fails, or 30 seconds have passed since the watch was opened.
+//
+// The results are, in this order, an error and an "exited" flag:
+//   - (nil, false): the pod reported the Ready condition as true
+//   - (nil, true): the pod phase is Succeeded or Failed
+//   - (IsNotFoundError, false): the pod no longer exists
+//   - (other error, false): the watch could not be opened, a poll failed, an
+//     unexpected object arrived on the watch, or the wait timed out
+func (a *Agent) waitForPod(pod *v1.Pod) (error, bool) {
+	var (
+		w   watch.Interface
+		err error
+		ok  bool
+	)
+	// immediately after creating a pod, the API may return a 404, so opening
+	// the watch is attempted a few times with a one second pause after each
+	// failed attempt.
+	watchRetries := 3
+	for i := 0; i < watchRetries; i++ {
+		selector := fields.OneTermEqualSelector("metadata.name", pod.Name).String()
+		w, err = a.Clientset.CoreV1().
+			Pods(pod.Namespace).
+			Watch(context.Background(), metav1.ListOptions{FieldSelector: selector})
+
+		if err == nil {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+	if err != nil {
+		return err, false
+	}
+	defer w.Stop()
+
+	// the deadline and the poll ticker are created once, outside the loop.
+	// Creating them inside the select would re-arm the 30 second deadline on
+	// every pass, so the one second tick would always win and the deadline
+	// would never fire.
+	deadline := time.NewTimer(30 * time.Second)
+	defer deadline.Stop()
+
+	poll := time.NewTicker(time.Second)
+	defer poll.Stop()
+
+	events := w.ResultChan()
+
+	for {
+		select {
+		case <-deadline.C:
+			return goerrors.New("timed out waiting for pod"), false
+		case <-poll.C:
+			// poll every second in case we already missed the ready event while
+			// creating the listener.
+			pod, err = a.Clientset.CoreV1().
+				Pods(pod.Namespace).
+				Get(context.Background(), pod.Name, metav1.GetOptions{})
+
+			if err != nil && errors.IsNotFound(err) {
+				return IsNotFoundError, false
+			} else if err != nil {
+				return err, false
+			}
+
+			if isExited := isPodExited(pod); isExited || isPodReady(pod) {
+				return nil, isExited
+			}
+		case evt, open := <-events:
+			if !open {
+				// the watch channel was closed. A closed channel is always
+				// ready to receive, so stop selecting on it (a nil channel
+				// never becomes ready) and rely on polling until the deadline.
+				events = nil
+				continue
+			}
+
+			pod, ok = evt.Object.(*v1.Pod)
+			if !ok {
+				return fmt.Errorf("unexpected object type: %T", evt.Object), false
+			}
+			if isExited := isPodExited(pod); isExited || isPodReady(pod) {
+				return nil, isExited
+			}
+		}
+	}
+}
+
+func isPodReady(pod *v1.Pod) bool {
+	ready := false
+	conditions := pod.Status.Conditions
+	for i := range conditions {
+		if conditions[i].Type == v1.PodReady {
+			ready = pod.Status.Conditions[i].Status == v1.ConditionTrue
+		}
+	}
+	return ready
+}
+
+// isPodExited reports whether the pod phase is Succeeded or Failed.
+func isPodExited(pod *v1.Pod) bool {
+	switch pod.Status.Phase {
+	case v1.PodSucceeded, v1.PodFailed:
+		return true
+	default:
+		return false
+	}
+}

+ 4 - 10
internal/kubernetes/provisioner/resource_stream.go

@@ -4,20 +4,17 @@ import (
 	"context"
 
 	redis "github.com/go-redis/redis/v8"
-	"github.com/gorilla/websocket"
+	"github.com/porter-dev/porter/api/server/shared/websocket"
 )
 
 // ResourceStream performs an XREAD operation on the given stream and outputs it to the given websocket conn.
-func ResourceStream(client *redis.Client, streamName string, conn *websocket.Conn) error {
+func ResourceStream(client *redis.Client, streamName string, rw *websocket.WebsocketSafeReadWriter) error {
 	errorchan := make(chan error)
 
 	go func() {
 		// listens for websocket closing handshake
 		for {
-			_, _, err := conn.ReadMessage()
-
-			if err != nil {
-				defer conn.Close()
+			if _, _, err := rw.ReadMessage(); err != nil {
 				errorchan <- nil
 				return
 			}
@@ -43,10 +40,7 @@ func ResourceStream(client *redis.Client, streamName string, conn *websocket.Con
 			messages := xstream[0].Messages
 			lastID = messages[len(messages)-1].ID
 
-			if writeErr := conn.WriteJSON(messages); writeErr != nil {
-				errorchan <- writeErr
-				return
-			}
+			rw.WriteJSONWithChannel(messages, errorchan)
 		}
 	}()