|
|
@@ -19,6 +19,7 @@ package mesh
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "context"
|
|
|
"fmt"
|
|
|
"io/ioutil"
|
|
|
"net"
|
|
|
@@ -69,7 +70,6 @@ type Mesh struct {
|
|
|
pub wgtypes.Key
|
|
|
resyncPeriod time.Duration
|
|
|
iptablesForwardRule bool
|
|
|
- stop chan struct{}
|
|
|
subnet *net.IPNet
|
|
|
table *route.Table
|
|
|
wireGuardIP *net.IPNet
|
|
|
@@ -180,7 +180,6 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
|
|
|
resyncPeriod: resyncPeriod,
|
|
|
iptablesForwardRule: iptablesForwardRule,
|
|
|
local: local,
|
|
|
- stop: make(chan struct{}),
|
|
|
subnet: subnet,
|
|
|
table: route.NewTable(),
|
|
|
errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
|
|
|
@@ -208,8 +207,8 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
|
|
|
}
|
|
|
|
|
|
// Run starts the mesh.
|
|
|
-func (m *Mesh) Run() error {
|
|
|
- if err := m.Nodes().Init(m.stop); err != nil {
|
|
|
+func (m *Mesh) Run(ctx context.Context) error {
|
|
|
+ if err := m.Nodes().Init(ctx); err != nil {
|
|
|
return fmt.Errorf("failed to initialize node backend: %v", err)
|
|
|
}
|
|
|
// Try to set the CNI config quickly.
|
|
|
@@ -221,14 +220,14 @@ func (m *Mesh) Run() error {
|
|
|
level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
|
|
|
}
|
|
|
}
|
|
|
- if err := m.Peers().Init(m.stop); err != nil {
|
|
|
+ if err := m.Peers().Init(ctx); err != nil {
|
|
|
return fmt.Errorf("failed to initialize peer backend: %v", err)
|
|
|
}
|
|
|
- ipTablesErrors, err := m.ipTables.Run(m.stop)
|
|
|
+ ipTablesErrors, err := m.ipTables.Run(ctx.Done())
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to watch for IP tables updates: %v", err)
|
|
|
}
|
|
|
- routeErrors, err := m.table.Run(m.stop)
|
|
|
+ routeErrors, err := m.table.Run(ctx.Done())
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to watch for route table updates: %v", err)
|
|
|
}
|
|
|
@@ -238,7 +237,7 @@ func (m *Mesh) Run() error {
|
|
|
select {
|
|
|
case err = <-ipTablesErrors:
|
|
|
case err = <-routeErrors:
|
|
|
- case <-m.stop:
|
|
|
+ case <-ctx.Done():
|
|
|
return
|
|
|
}
|
|
|
if err != nil {
|
|
|
@@ -257,11 +256,11 @@ func (m *Mesh) Run() error {
|
|
|
for {
|
|
|
select {
|
|
|
case ne = <-nw:
|
|
|
- m.syncNodes(ne)
|
|
|
+ m.syncNodes(ctx, ne)
|
|
|
case pe = <-pw:
|
|
|
m.syncPeers(pe)
|
|
|
case <-checkIn.C:
|
|
|
- m.checkIn()
|
|
|
+ m.checkIn(ctx)
|
|
|
checkIn.Reset(checkInPeriod)
|
|
|
case <-resync.C:
|
|
|
if m.cni {
|
|
|
@@ -269,18 +268,18 @@ func (m *Mesh) Run() error {
|
|
|
}
|
|
|
m.applyTopology()
|
|
|
resync.Reset(m.resyncPeriod)
|
|
|
- case <-m.stop:
|
|
|
+ case <-ctx.Done():
|
|
|
return nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (m *Mesh) syncNodes(e *NodeEvent) {
|
|
|
+func (m *Mesh) syncNodes(ctx context.Context, e *NodeEvent) {
|
|
|
logger := log.With(m.logger, "event", e.Type)
|
|
|
level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
|
|
|
if isSelf(m.hostname, e.Node) {
|
|
|
level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
|
|
|
- m.handleLocal(e.Node)
|
|
|
+ m.handleLocal(ctx, e.Node)
|
|
|
return
|
|
|
}
|
|
|
var diff bool
|
|
|
@@ -348,7 +347,7 @@ func (m *Mesh) syncPeers(e *PeerEvent) {
|
|
|
|
|
|
// checkIn will try to update the local node's LastSeen timestamp
|
|
|
// in the backend.
|
|
|
-func (m *Mesh) checkIn() {
|
|
|
+func (m *Mesh) checkIn(ctx context.Context) {
|
|
|
m.mu.Lock()
|
|
|
defer m.mu.Unlock()
|
|
|
n := m.nodes[m.hostname]
|
|
|
@@ -358,7 +357,7 @@ func (m *Mesh) checkIn() {
|
|
|
}
|
|
|
oldTime := n.LastSeen
|
|
|
n.LastSeen = time.Now().Unix()
|
|
|
- if err := m.Nodes().Set(m.hostname, n); err != nil {
|
|
|
+ if err := m.Nodes().Set(ctx, m.hostname, n); err != nil {
|
|
|
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
|
|
|
m.errorCounter.WithLabelValues("checkin").Inc()
|
|
|
// Revert time.
|
|
|
@@ -368,7 +367,7 @@ func (m *Mesh) checkIn() {
|
|
|
level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
|
|
|
}
|
|
|
|
|
|
-func (m *Mesh) handleLocal(n *Node) {
|
|
|
+func (m *Mesh) handleLocal(ctx context.Context, n *Node) {
|
|
|
// Allow the IPs to be overridden.
|
|
|
if !n.Endpoint.Ready() {
|
|
|
e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
|
|
|
@@ -399,7 +398,7 @@ func (m *Mesh) handleLocal(n *Node) {
|
|
|
}
|
|
|
if !nodesAreEqual(n, local) {
|
|
|
level.Debug(m.logger).Log("msg", "local node differs from backend")
|
|
|
- if err := m.Nodes().Set(m.hostname, local); err != nil {
|
|
|
+ if err := m.Nodes().Set(ctx, m.hostname, local); err != nil {
|
|
|
level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
|
|
|
m.errorCounter.WithLabelValues("local").Inc()
|
|
|
return
|
|
|
@@ -584,11 +583,6 @@ func (m *Mesh) RegisterMetrics(r prometheus.Registerer) {
|
|
|
)
|
|
|
}
|
|
|
|
|
|
-// Stop stops the mesh.
|
|
|
-func (m *Mesh) Stop() {
|
|
|
- close(m.stop)
|
|
|
-}
|
|
|
-
|
|
|
func (m *Mesh) cleanUp() {
|
|
|
if err := m.ipTables.CleanUp(); err != nil {
|
|
|
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
|
|
|
@@ -604,13 +598,21 @@ func (m *Mesh) cleanUp() {
|
|
|
m.errorCounter.WithLabelValues("cleanUp").Inc()
|
|
|
}
|
|
|
}
|
|
|
- if err := m.Nodes().CleanUp(m.hostname); err != nil {
|
|
|
- level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
|
|
|
- m.errorCounter.WithLabelValues("cleanUp").Inc()
|
|
|
+ {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+ if err := m.Nodes().CleanUp(ctx, m.hostname); err != nil {
|
|
|
+ level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
|
|
|
+ m.errorCounter.WithLabelValues("cleanUp").Inc()
|
|
|
+ }
|
|
|
}
|
|
|
- if err := m.Peers().CleanUp(m.hostname); err != nil {
|
|
|
- level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
|
|
|
- m.errorCounter.WithLabelValues("cleanUp").Inc()
|
|
|
+ {
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
+ defer cancel()
|
|
|
+ if err := m.Peers().CleanUp(ctx, m.hostname); err != nil {
|
|
|
+ level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
|
|
|
+ m.errorCounter.WithLabelValues("cleanUp").Inc()
|
|
|
+ }
|
|
|
}
|
|
|
if err := m.enc.CleanUp(); err != nil {
|
|
|
level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulator: %v", err))
|