Prechádzať zdrojové kódy

pkg/wireguard: ignore changes to peers behind NAT

This commit enables Kilo to ignore changes to the endpoints of peers
that sit behind a NAT gateway. We use the heuristic of a non-zero
persistent keepalive to decide whether the endpoint field should be
ignored. This will allow NATed peers to roam and for every node in the
cluster to have a different value for a peer's endpoint, as is natural
when a peer's connections are NATed.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
Lucas Servén Marín 6 rokov pred
rodič
commit
0d199db009
2 zmenil súbory, kde vykonal 70 pridanie a 20 odobranie
  1. 36 1
      pkg/mesh/mesh.go
  2. 34 19
      pkg/wireguard/conf.go

+ 36 - 1
pkg/mesh/mesh.go

@@ -659,7 +659,7 @@ func (m *Mesh) applyTopology() {
 		}
 		// Setting the WireGuard configuration interrupts existing connections
 		// so only set the configuration if it has changed.
-		equal := conf.Equal(wireguard.Parse(oldConf))
+		equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT)
 		if !equal {
 			level.Info(m.logger).Log("msg", "WireGuard configurations are different")
 			if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
@@ -856,6 +856,41 @@ func peersAreEqual(a, b *Peer) bool {
 	return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
 }
 
+// Basic nil checks and checking the lengths of the allowed IPs is
+// done by the WireGuard package.
+func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
+	for j := range a.AllowedIPs {
+		if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
+			return false
+		}
+	}
+	if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
+		return false
+	}
+	// If a persistent keepalive is set, then the peer is behind NAT
+	// and we want to ignore changes in endpoints, since it may roam.
+	if a.PersistentKeepalive != 0 {
+		return true
+	}
+	if (a.Endpoint == nil) != (b.Endpoint == nil) {
+		return false
+	}
+	if a.Endpoint != nil {
+		if a.Endpoint.Port != b.Endpoint.Port {
+			return false
+		}
+		// IPs take priority, so check them first.
+		if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
+			return false
+		}
+		// Only check the DNS name if the IP is empty.
+		if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
+			return false
+		}
+	}
+	return true
+}
+
 func ipNetsEqual(a, b *net.IPNet) bool {
 	if a == nil && b == nil {
 		return true

+ 34 - 19
pkg/wireguard/conf.go

@@ -275,6 +275,12 @@ func (c *Conf) Bytes() ([]byte, error) {
 
 // Equal checks if two WireGuard configurations are equivalent.
 func (c *Conf) Equal(b *Conf) bool {
+	return c.EqualWithPeerCheck(b, strictPeerCheck)
+}
+
+// EqualWithPeerCheck checks if two WireGuard configurations are equivalent
+// when their peers are compared using the given peer comparison func.
+func (c *Conf) EqualWithPeerCheck(b *Conf, pc PeerCheck) bool {
 	if (c.Interface == nil) != (b.Interface == nil) {
 		return false
 	}
@@ -288,38 +294,47 @@ func (c *Conf) Equal(b *Conf) bool {
 	}
 	sortPeers(c.Peers)
 	sortPeers(b.Peers)
+	var ok bool
 	for i := range c.Peers {
 		if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
 			return false
 		}
 		sortCIDRs(c.Peers[i].AllowedIPs)
 		sortCIDRs(b.Peers[i].AllowedIPs)
-		for j := range c.Peers[i].AllowedIPs {
-			if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
-				return false
-			}
+		if ok = pc(c.Peers[i], b.Peers[i]); !ok {
+			return false
+		}
+	}
+	return true
+
+}
+
+// PeerCheck is a function that compares two peers.
+type PeerCheck func(a, b *Peer) bool
+
+func strictPeerCheck(a, b *Peer) bool {
+	for j := range a.AllowedIPs {
+		if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
+			return false
 		}
-		if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
+	}
+	if (a.Endpoint == nil) != (b.Endpoint == nil) {
+		return false
+	}
+	if a.Endpoint != nil {
+		if a.Endpoint.Port != b.Endpoint.Port {
 			return false
 		}
-		if c.Peers[i].Endpoint != nil {
-			if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
-				return false
-			}
-			// IPs take priority, so check them first.
-			if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
-				return false
-			}
-			// Only check the DNS name if the IP is empty.
-			if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
-				return false
-			}
+		// IPs take priority, so check them first.
+		if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
+			return false
 		}
-		if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
+		// Only check the DNS name if the IP is empty.
+		if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
 			return false
 		}
 	}
-	return true
+	return a.PersistentKeepalive == b.PersistentKeepalive && bytes.Equal(a.PublicKey, b.PublicKey)
 }
 
 func sortPeers(peers []*Peer) {