Procházet zdrojové kódy

Merge pull request #1651 from porter-dev/nafees/goroutine-leaks

[POR-103] Use WaitGroups to get rid of goroutine leaks in websocket stream method
abelanger5 před 4 roky
rodič
revize
5c0bda758e

+ 4 - 0
api/server/shared/websocket/response_writer.go

@@ -46,6 +46,10 @@ func (w *WebsocketSafeReadWriter) ReadMessage() (messageType int, p []byte, err
 	return w.conn.ReadMessage()
 	return w.conn.ReadMessage()
 }
 }
 
 
+func (w *WebsocketSafeReadWriter) Close() error {
+	return w.conn.Close()
+}
+
 type WebsocketResponseWriter struct {
 type WebsocketResponseWriter struct {
 	conn       *websocket.Conn
 	conn       *websocket.Conn
 	safeWriter *WebsocketSafeReadWriter
 	safeWriter *WebsocketSafeReadWriter

+ 54 - 28
internal/kubernetes/agent.go

@@ -11,6 +11,7 @@ import (
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"strings"
 	"strings"
+	"sync"
 	"time"
 	"time"
 
 
 	goerrors "errors"
 	goerrors "errors"
@@ -605,13 +606,21 @@ func (a *Agent) GetPodLogs(namespace string, name string, selectedContainer stri
 		return fmt.Errorf("Cannot open log stream for pod %s: %s", name, err.Error())
 		return fmt.Errorf("Cannot open log stream for pod %s: %s", name, err.Error())
 	}
 	}
 
 
-	defer podLogs.Close()
-
 	r := bufio.NewReader(podLogs)
 	r := bufio.NewReader(podLogs)
 	errorchan := make(chan error)
 	errorchan := make(chan error)
 
 
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		wg.Wait()
+		close(errorchan)
+	}()
+
 	go func() {
 	go func() {
 		// listens for websocket closing handshake
 		// listens for websocket closing handshake
+		defer wg.Done()
+
 		for {
 		for {
 			if _, _, err := rw.ReadMessage(); err != nil {
 			if _, _, err := rw.ReadMessage(); err != nil {
 				errorchan <- nil
 				errorchan <- nil
@@ -621,11 +630,13 @@ func (a *Agent) GetPodLogs(namespace string, name string, selectedContainer stri
 	}()
 	}()
 
 
 	go func() {
 	go func() {
+		defer wg.Done()
+
 		for {
 		for {
 			bytes, err := r.ReadBytes('\n')
 			bytes, err := r.ReadBytes('\n')
 
 
-			if err == io.EOF {
-				errorchan <- nil
+			if err != nil {
+				errorchan <- err
 				return
 				return
 			}
 			}
 
 
@@ -633,22 +644,15 @@ func (a *Agent) GetPodLogs(namespace string, name string, selectedContainer stri
 				errorchan <- writeErr
 				errorchan <- writeErr
 				return
 				return
 			}
 			}
-
-			select {
-			case <-errorchan:
-				return
-			default:
-			}
 		}
 		}
 	}()
 	}()
 
 
-	for {
-		select {
-		case err = <-errorchan:
-			close(errorchan)
-			return err
-		}
+	for err = range errorchan {
+		rw.Close()
+		podLogs.Close()
 	}
 	}
+
+	return err
 }
 }
 
 
 // GetPodLogs streams real-time logs from a given pod.
 // GetPodLogs streams real-time logs from a given pod.
@@ -836,7 +840,6 @@ func (a *Agent) StreamControllerStatus(kind string, selectors string, rw *websoc
 
 
 		stopper := make(chan struct{})
 		stopper := make(chan struct{})
 		errorchan := make(chan error)
 		errorchan := make(chan error)
-		defer close(stopper)
 
 
 		informer.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
 		informer.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
 			if strings.HasSuffix(err.Error(), ": Unauthorized") {
 			if strings.HasSuffix(err.Error(), ": Unauthorized") {
@@ -871,8 +874,20 @@ func (a *Agent) StreamControllerStatus(kind string, selectors string, rw *websoc
 			},
 			},
 		})
 		})
 
 
+		var wg sync.WaitGroup
+		var err error
+
+		wg.Add(1)
+
+		go func() {
+			wg.Wait()
+			close(errorchan)
+		}()
+
 		go func() {
 		go func() {
 			// listens for websocket closing handshake
 			// listens for websocket closing handshake
+			defer wg.Done()
+
 			for {
 			for {
 				if _, _, err := rw.ReadMessage(); err != nil {
 				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					errorchan <- nil
@@ -883,12 +898,12 @@ func (a *Agent) StreamControllerStatus(kind string, selectors string, rw *websoc
 
 
 		go informer.Run(stopper)
 		go informer.Run(stopper)
 
 
-		for {
-			select {
-			case err := <-errorchan:
-				return err
-			}
+		for err = range errorchan {
+			close(stopper)
+			rw.Close()
 		}
 		}
+
+		return err
 	}
 	}
 
 
 	return a.RunWebsocketTask(run)
 	return a.RunWebsocketTask(run)
@@ -980,7 +995,6 @@ func (a *Agent) StreamHelmReleases(namespace string, chartList []string, selecto
 
 
 		stopper := make(chan struct{})
 		stopper := make(chan struct{})
 		errorchan := make(chan error)
 		errorchan := make(chan error)
-		defer close(stopper)
 
 
 		informer.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
 		informer.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
 			if strings.HasSuffix(err.Error(), ": Unauthorized") {
 			if strings.HasSuffix(err.Error(), ": Unauthorized") {
@@ -1069,8 +1083,20 @@ func (a *Agent) StreamHelmReleases(namespace string, chartList []string, selecto
 			},
 			},
 		})
 		})
 
 
+		var wg sync.WaitGroup
+		var err error
+
+		wg.Add(1)
+
+		go func() {
+			wg.Wait()
+			close(errorchan)
+		}()
+
 		go func() {
 		go func() {
 			// listens for websocket closing handshake
 			// listens for websocket closing handshake
+			defer wg.Done()
+
 			for {
 			for {
 				if _, _, err := rw.ReadMessage(); err != nil {
 				if _, _, err := rw.ReadMessage(); err != nil {
 					errorchan <- nil
 					errorchan <- nil
@@ -1081,12 +1107,12 @@ func (a *Agent) StreamHelmReleases(namespace string, chartList []string, selecto
 
 
 		go informer.Run(stopper)
 		go informer.Run(stopper)
 
 
-		for {
-			select {
-			case err := <-errorchan:
-				return err
-			}
+		for err = range errorchan {
+			close(stopper)
+			rw.Close()
 		}
 		}
+
+		return err
 	}
 	}
 
 
 	return a.RunWebsocketTask(run)
 	return a.RunWebsocketTask(run)