소스 검색

Remove context.TODO() (#310)

Remove almost all (except the ones created by informer-gen)
context.TODOs.

Signed-off-by: leonnicolas <leonloechner@gmx.de>
leonnicolas 4 년 전
부모
커밋
0820a9d32f
5개의 변경된 파일58개의 추가작업 그리고 53개의 파일을 삭제
  1. 4 2
      cmd/kg/main.go
  2. 3 3
      cmd/kgctl/main.go
  3. 14 14
      pkg/k8s/backend.go
  4. 7 6
      pkg/mesh/backend.go
  5. 30 28
      pkg/mesh/mesh.go

+ 4 - 2
cmd/kg/main.go

@@ -15,6 +15,7 @@
 package main
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -275,15 +276,16 @@ func runRoot(_ *cobra.Command, _ []string) error {
 	}
 
 	{
+		ctx, cancel := context.WithCancel(context.Background())
 		// Start the mesh.
 		g.Add(func() error {
 			logger.Log("msg", fmt.Sprintf("Starting Kilo network mesh '%v'.", version.Version))
-			if err := m.Run(); err != nil {
+			if err := m.Run(ctx); err != nil {
 				return fmt.Errorf("error: Kilo exited unexpectedly: %v", err)
 			}
 			return nil
 		}, func(error) {
-			m.Stop()
+			cancel()
 		})
 	}
 

+ 3 - 3
cmd/kgctl/main.go

@@ -71,7 +71,7 @@ var (
 	topologyLabel string
 )
 
-func runRoot(_ *cobra.Command, _ []string) error {
+func runRoot(c *cobra.Command, _ []string) error {
 	if opts.port < 1 || opts.port > 1<<16-1 {
 		return fmt.Errorf("invalid port: port mus be in range [%d:%d], but got %d", 1, 1<<16-1, opts.port)
 	}
@@ -99,11 +99,11 @@ func runRoot(_ *cobra.Command, _ []string) error {
 		return fmt.Errorf("backend %s unknown; posible values are: %s", backend, availableBackends)
 	}
 
-	if err := opts.backend.Nodes().Init(make(chan struct{})); err != nil {
+	if err := opts.backend.Nodes().Init(c.Context()); err != nil {
 		return fmt.Errorf("failed to initialize node backend: %w", err)
 	}
 
-	if err := opts.backend.Peers().Init(make(chan struct{})); err != nil {
+	if err := opts.backend.Peers().Init(c.Context()); err != nil {
 		return fmt.Errorf("failed to initialize peer backend: %w", err)
 	}
 	return nil

+ 14 - 14
pkg/k8s/backend.go

@@ -128,7 +128,7 @@ func New(c kubernetes.Interface, kc kiloclient.Interface, ec apiextensions.Inter
 }
 
 // CleanUp removes configuration applied to the backend.
-func (nb *nodeBackend) CleanUp(name string) error {
+func (nb *nodeBackend) CleanUp(ctx context.Context, name string) error {
 	patch := []byte("[" + strings.Join([]string{
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(endpointAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(internalIPAnnotationKey, "/", jsonPatchSlash, 1))),
@@ -138,7 +138,7 @@ func (nb *nodeBackend) CleanUp(name string) error {
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(discoveredEndpointsKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(granularityKey, "/", jsonPatchSlash, 1))),
 	}, ",") + "]")
-	if _, err := nb.client.CoreV1().Nodes().Patch(context.TODO(), name, types.JSONPatchType, patch, metav1.PatchOptions{}); err != nil {
+	if _, err := nb.client.CoreV1().Nodes().Patch(ctx, name, types.JSONPatchType, patch, metav1.PatchOptions{}); err != nil {
 		return fmt.Errorf("failed to patch node: %v", err)
 	}
 	return nil
@@ -155,9 +155,9 @@ func (nb *nodeBackend) Get(name string) (*mesh.Node, error) {
 
 // Init initializes the backend; for this backend that means
 // syncing the informer cache.
-func (nb *nodeBackend) Init(stop <-chan struct{}) error {
-	go nb.informer.Run(stop)
-	if ok := cache.WaitForCacheSync(stop, func() bool {
+func (nb *nodeBackend) Init(ctx context.Context) error {
+	go nb.informer.Run(ctx.Done())
+	if ok := cache.WaitForCacheSync(ctx.Done(), func() bool {
 		return nb.informer.HasSynced()
 	}); !ok {
 		return errors.New("failed to sync node cache")
@@ -212,7 +212,7 @@ func (nb *nodeBackend) List() ([]*mesh.Node, error) {
 }
 
 // Set sets the fields of a node.
-func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
+func (nb *nodeBackend) Set(ctx context.Context, name string, node *mesh.Node) error {
 	old, err := nb.lister.Get(name)
 	if err != nil {
 		return fmt.Errorf("failed to find node: %v", err)
@@ -253,7 +253,7 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
 	if err != nil {
 		return fmt.Errorf("failed to create patch for node %q: %v", n.Name, err)
 	}
-	if _, err = nb.client.CoreV1().Nodes().Patch(context.TODO(), name, types.StrategicMergePatchType, patch, metav1.PatchOptions{}); err != nil {
+	if _, err = nb.client.CoreV1().Nodes().Patch(ctx, name, types.StrategicMergePatchType, patch, metav1.PatchOptions{}); err != nil {
 		return fmt.Errorf("failed to patch node: %v", err)
 	}
 	return nil
@@ -431,7 +431,7 @@ func translatePeer(peer *v1alpha1.Peer) *mesh.Peer {
 }
 
 // CleanUp removes configuration applied to the backend.
-func (pb *peerBackend) CleanUp(name string) error {
+func (pb *peerBackend) CleanUp(_ context.Context, _ string) error {
 	return nil
 }
 
@@ -446,14 +446,14 @@ func (pb *peerBackend) Get(name string) (*mesh.Peer, error) {
 
 // Init initializes the backend; for this backend that means
 // syncing the informer cache.
-func (pb *peerBackend) Init(stop <-chan struct{}) error {
+func (pb *peerBackend) Init(ctx context.Context) error {
 	// Check the presents of the CRD peers.kilo.squat.ai.
-	if _, err := pb.extensionsClient.ApiextensionsV1().CustomResourceDefinitions().Get(context.TODO(), strings.Join([]string{v1alpha1.PeerPlural, v1alpha1.GroupName}, "."), metav1.GetOptions{}); err != nil {
+	if _, err := pb.extensionsClient.ApiextensionsV1().CustomResourceDefinitions().Get(ctx, strings.Join([]string{v1alpha1.PeerPlural, v1alpha1.GroupName}, "."), metav1.GetOptions{}); err != nil {
 		return fmt.Errorf("CRD is not present: %v", err)
 	}
 
-	go pb.informer.Run(stop)
-	if ok := cache.WaitForCacheSync(stop, func() bool {
+	go pb.informer.Run(ctx.Done())
+	if ok := cache.WaitForCacheSync(ctx.Done(), func() bool {
 		return pb.informer.HasSynced()
 	}); !ok {
 		return errors.New("failed to sync peer cache")
@@ -512,7 +512,7 @@ func (pb *peerBackend) List() ([]*mesh.Peer, error) {
 }
 
 // Set sets the fields of a peer.
-func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
+func (pb *peerBackend) Set(ctx context.Context, name string, peer *mesh.Peer) error {
 	old, err := pb.lister.Get(name)
 	if err != nil {
 		return fmt.Errorf("failed to find peer: %v", err)
@@ -542,7 +542,7 @@ func (pb *peerBackend) Set(name string, peer *mesh.Peer) error {
 		p.Spec.PresharedKey = peer.PresharedKey.String()
 	}
 	p.Spec.PublicKey = peer.PublicKey.String()
-	if _, err = pb.client.KiloV1alpha1().Peers().Update(context.TODO(), p, metav1.UpdateOptions{}); err != nil {
+	if _, err = pb.client.KiloV1alpha1().Peers().Update(ctx, p, metav1.UpdateOptions{}); err != nil {
 		return fmt.Errorf("failed to update peer: %v", err)
 	}
 	return nil

+ 7 - 6
pkg/mesh/backend.go

@@ -15,6 +15,7 @@
 package mesh
 
 import (
+	"context"
 	"net"
 	"time"
 
@@ -146,11 +147,11 @@ type Backend interface {
 // clean up any changes applied to the backend,
 // and watch for changes to nodes.
 type NodeBackend interface {
-	CleanUp(string) error
+	CleanUp(context.Context, string) error
 	Get(string) (*Node, error)
-	Init(<-chan struct{}) error
+	Init(context.Context) error
 	List() ([]*Node, error)
-	Set(string, *Node) error
+	Set(context.Context, string, *Node) error
 	Watch() <-chan *NodeEvent
 }
 
@@ -160,10 +161,10 @@ type NodeBackend interface {
 // clean up any changes applied to the backend,
 // and watch for changes to peers.
 type PeerBackend interface {
-	CleanUp(string) error
+	CleanUp(context.Context, string) error
 	Get(string) (*Peer, error)
-	Init(<-chan struct{}) error
+	Init(context.Context) error
 	List() ([]*Peer, error)
-	Set(string, *Peer) error
+	Set(context.Context, string, *Peer) error
 	Watch() <-chan *PeerEvent
 }

+ 30 - 28
pkg/mesh/mesh.go

@@ -19,6 +19,7 @@ package mesh
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"io/ioutil"
 	"net"
@@ -69,7 +70,6 @@ type Mesh struct {
 	pub                 wgtypes.Key
 	resyncPeriod        time.Duration
 	iptablesForwardRule bool
-	stop                chan struct{}
 	subnet              *net.IPNet
 	table               *route.Table
 	wireGuardIP         *net.IPNet
@@ -180,7 +180,6 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 		resyncPeriod:        resyncPeriod,
 		iptablesForwardRule: iptablesForwardRule,
 		local:               local,
-		stop:                make(chan struct{}),
 		subnet:              subnet,
 		table:               route.NewTable(),
 		errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
@@ -208,8 +207,8 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 }
 
 // Run starts the mesh.
-func (m *Mesh) Run() error {
-	if err := m.Nodes().Init(m.stop); err != nil {
+func (m *Mesh) Run(ctx context.Context) error {
+	if err := m.Nodes().Init(ctx); err != nil {
 		return fmt.Errorf("failed to initialize node backend: %v", err)
 	}
 	// Try to set the CNI config quickly.
@@ -221,14 +220,14 @@ func (m *Mesh) Run() error {
 			level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
 		}
 	}
-	if err := m.Peers().Init(m.stop); err != nil {
+	if err := m.Peers().Init(ctx); err != nil {
 		return fmt.Errorf("failed to initialize peer backend: %v", err)
 	}
-	ipTablesErrors, err := m.ipTables.Run(m.stop)
+	ipTablesErrors, err := m.ipTables.Run(ctx.Done())
 	if err != nil {
 		return fmt.Errorf("failed to watch for IP tables updates: %v", err)
 	}
-	routeErrors, err := m.table.Run(m.stop)
+	routeErrors, err := m.table.Run(ctx.Done())
 	if err != nil {
 		return fmt.Errorf("failed to watch for route table updates: %v", err)
 	}
@@ -238,7 +237,7 @@ func (m *Mesh) Run() error {
 			select {
 			case err = <-ipTablesErrors:
 			case err = <-routeErrors:
-			case <-m.stop:
+			case <-ctx.Done():
 				return
 			}
 			if err != nil {
@@ -257,11 +256,11 @@ func (m *Mesh) Run() error {
 	for {
 		select {
 		case ne = <-nw:
-			m.syncNodes(ne)
+			m.syncNodes(ctx, ne)
 		case pe = <-pw:
 			m.syncPeers(pe)
 		case <-checkIn.C:
-			m.checkIn()
+			m.checkIn(ctx)
 			checkIn.Reset(checkInPeriod)
 		case <-resync.C:
 			if m.cni {
@@ -269,18 +268,18 @@ func (m *Mesh) Run() error {
 			}
 			m.applyTopology()
 			resync.Reset(m.resyncPeriod)
-		case <-m.stop:
+		case <-ctx.Done():
 			return nil
 		}
 	}
 }
 
-func (m *Mesh) syncNodes(e *NodeEvent) {
+func (m *Mesh) syncNodes(ctx context.Context, e *NodeEvent) {
 	logger := log.With(m.logger, "event", e.Type)
 	level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
 	if isSelf(m.hostname, e.Node) {
 		level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
-		m.handleLocal(e.Node)
+		m.handleLocal(ctx, e.Node)
 		return
 	}
 	var diff bool
@@ -348,7 +347,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
 
 // checkIn will try to update the local node's LastSeen timestamp
 // in the backend.
-func (m *Mesh) checkIn() {
+func (m *Mesh) checkIn(ctx context.Context) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	n := m.nodes[m.hostname]
@@ -358,7 +357,7 @@ func (m *Mesh) checkIn() {
 	}
 	oldTime := n.LastSeen
 	n.LastSeen = time.Now().Unix()
-	if err := m.Nodes().Set(m.hostname, n); err != nil {
+	if err := m.Nodes().Set(ctx, m.hostname, n); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
 		m.errorCounter.WithLabelValues("checkin").Inc()
 		// Revert time.
@@ -368,7 +367,7 @@ func (m *Mesh) checkIn() {
 	level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
 }
 
-func (m *Mesh) handleLocal(n *Node) {
+func (m *Mesh) handleLocal(ctx context.Context, n *Node) {
 	// Allow the IPs to be overridden.
 	if !n.Endpoint.Ready() {
 		e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
@@ -399,7 +398,7 @@ func (m *Mesh) handleLocal(n *Node) {
 	}
 	if !nodesAreEqual(n, local) {
 		level.Debug(m.logger).Log("msg", "local node differs from backend")
-		if err := m.Nodes().Set(m.hostname, local); err != nil {
+		if err := m.Nodes().Set(ctx, m.hostname, local); err != nil {
 			level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
 			m.errorCounter.WithLabelValues("local").Inc()
 			return
@@ -584,11 +583,6 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
 	)
 }
 
-// Stop stops the mesh.
-func (m *Mesh) Stop() {
-	close(m.stop)
-}
-
 func (m *Mesh) cleanUp() {
 	if err := m.ipTables.CleanUp(); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
@@ -604,13 +598,21 @@ func (m *Mesh) cleanUp() {
 			m.errorCounter.WithLabelValues("cleanUp").Inc()
 		}
 	}
-	if err := m.Nodes().CleanUp(m.hostname); err != nil {
-		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
-		m.errorCounter.WithLabelValues("cleanUp").Inc()
+	{
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := m.Nodes().CleanUp(ctx, m.hostname); err != nil {
+			level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
+			m.errorCounter.WithLabelValues("cleanUp").Inc()
+		}
 	}
-	if err := m.Peers().CleanUp(m.hostname); err != nil {
-		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
-		m.errorCounter.WithLabelValues("cleanUp").Inc()
+	{
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := m.Peers().CleanUp(ctx, m.hostname); err != nil {
+			level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
+			m.errorCounter.WithLabelValues("cleanUp").Inc()
+		}
 	}
 	if err := m.enc.CleanUp(); err != nil {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulator: %v", err))