2
0
Эх сурвалжийг харах

Nat to nat (#146)

* wireguard: export an Endpoint comparison method

* Record discovered endpoints in node

* Synchronize DiscoveredEndpoints in k8s backend

* Add discoveredEndpointsAreEqual

* Handle discovered Endpoints in topology to enable NAT 2 NAT

* Refactor to use Endpoint.Equal

Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.

* Drop the shallow copies of nodes and peers

Now that updateNATEndpoints was updated to discoverNATEndpoints and that
the endpoints are overridden by topology instead of mutating the nodes and
peers object, we can safely drop this copy.
Julien Viard de Galbert 5 жил өмнө
parent
commit
2ac000c68a

+ 21 - 1
pkg/k8s/backend.go

@@ -59,6 +59,7 @@ const (
 	locationAnnotationKey        = "kilo.squat.ai/location"
 	locationAnnotationKey        = "kilo.squat.ai/location"
 	persistentKeepaliveKey       = "kilo.squat.ai/persistent-keepalive"
 	persistentKeepaliveKey       = "kilo.squat.ai/persistent-keepalive"
 	wireGuardIPAnnotationKey     = "kilo.squat.ai/wireguard-ip"
 	wireGuardIPAnnotationKey     = "kilo.squat.ai/wireguard-ip"
+	discoveredEndpointsKey       = "kilo.squat.ai/discovered-endpoints"
 	// RegionLabelKey is the key for the well-known Kubernetes topology region label.
 	// RegionLabelKey is the key for the well-known Kubernetes topology region label.
 	RegionLabelKey  = "topology.kubernetes.io/region"
 	RegionLabelKey  = "topology.kubernetes.io/region"
 	jsonPatchSlash  = "~1"
 	jsonPatchSlash  = "~1"
@@ -127,6 +128,7 @@ func (nb *nodeBackend) CleanUp(name string) error {
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(keyAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(keyAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(lastSeenAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(lastSeenAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(wireGuardIPAnnotationKey, "/", jsonPatchSlash, 1))),
 		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(wireGuardIPAnnotationKey, "/", jsonPatchSlash, 1))),
+		fmt.Sprintf(jsonRemovePatch, path.Join("/metadata", "annotations", strings.Replace(discoveredEndpointsKey, "/", jsonPatchSlash, 1))),
 	}, ",") + "]")
 	}, ",") + "]")
 	if _, err := nb.client.CoreV1().Nodes().Patch(name, types.JSONPatchType, patch); err != nil {
 	if _, err := nb.client.CoreV1().Nodes().Patch(name, types.JSONPatchType, patch); err != nil {
 		return fmt.Errorf("failed to patch node: %v", err)
 		return fmt.Errorf("failed to patch node: %v", err)
@@ -221,6 +223,15 @@ func (nb *nodeBackend) Set(name string, node *mesh.Node) error {
 	} else {
 	} else {
 		n.ObjectMeta.Annotations[wireGuardIPAnnotationKey] = node.WireGuardIP.String()
 		n.ObjectMeta.Annotations[wireGuardIPAnnotationKey] = node.WireGuardIP.String()
 	}
 	}
+	if node.DiscoveredEndpoints == nil {
+		n.ObjectMeta.Annotations[discoveredEndpointsKey] = ""
+	} else {
+		discoveredEndpoints, err := json.Marshal(node.DiscoveredEndpoints)
+		if err != nil {
+			return err
+		}
+		n.ObjectMeta.Annotations[discoveredEndpointsKey] = string(discoveredEndpoints)
+	}
 	oldData, err := json.Marshal(old)
 	oldData, err := json.Marshal(old)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -294,6 +305,14 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
 			lastSeen = 0
 			lastSeen = 0
 		}
 		}
 	}
 	}
+	var discoveredEndpoints map[string]*wireguard.Endpoint
+	if de, ok := node.ObjectMeta.Annotations[discoveredEndpointsKey]; ok {
+		err := json.Unmarshal([]byte(de), &discoveredEndpoints)
+		if err != nil {
+			discoveredEndpoints = nil
+		}
+	}
+
 	return &mesh.Node{
 	return &mesh.Node{
 		// Endpoint and InternalIP should only ever fail to parse if the
 		// Endpoint and InternalIP should only ever fail to parse if the
 		// remote node's agent has not yet set its IP address;
 		// remote node's agent has not yet set its IP address;
@@ -314,7 +333,8 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
 		// WireGuardIP can fail to parse if the node is not a leader or if
 		// WireGuardIP can fail to parse if the node is not a leader or if
 		// the node's agent has not yet reconciled. In either case, the IP
 		// the node's agent has not yet reconciled. In either case, the IP
 		// will parse as nil.
 		// will parse as nil.
-		WireGuardIP: normalizeIP(node.ObjectMeta.Annotations[wireGuardIPAnnotationKey]),
+		WireGuardIP:         normalizeIP(node.ObjectMeta.Annotations[wireGuardIPAnnotationKey]),
+		DiscoveredEndpoints: discoveredEndpoints,
 	}
 	}
 }
 }
 
 

+ 1 - 0
pkg/mesh/backend.go

@@ -66,6 +66,7 @@ type Node struct {
 	PersistentKeepalive int
 	PersistentKeepalive int
 	Subnet              *net.IPNet
 	Subnet              *net.IPNet
 	WireGuardIP         *net.IPNet
 	WireGuardIP         *net.IPNet
+	DiscoveredEndpoints map[string]*wireguard.Endpoint
 }
 }
 
 
 // Ready indicates whether or not the node is ready.
 // Ready indicates whether or not the node is ready.

+ 46 - 41
pkg/mesh/mesh.go

@@ -389,6 +389,7 @@ func (m *Mesh) handleLocal(n *Node) {
 		PersistentKeepalive: n.PersistentKeepalive,
 		PersistentKeepalive: n.PersistentKeepalive,
 		Subnet:              n.Subnet,
 		Subnet:              n.Subnet,
 		WireGuardIP:         m.wireGuardIP,
 		WireGuardIP:         m.wireGuardIP,
+		DiscoveredEndpoints: n.DiscoveredEndpoints,
 	}
 	}
 	if !nodesAreEqual(n, local) {
 	if !nodesAreEqual(n, local) {
 		level.Debug(m.logger).Log("msg", "local node differs from backend")
 		level.Debug(m.logger).Log("msg", "local node differs from backend")
@@ -431,9 +432,8 @@ func (m *Mesh) applyTopology() {
 		if !m.nodes[k].Ready() {
 		if !m.nodes[k].Ready() {
 			continue
 			continue
 		}
 		}
-		// Make a shallow copy of the node.
-		node := *m.nodes[k]
-		nodes[k] = &node
+		// Make it point to the node without copy.
+		nodes[k] = m.nodes[k]
 		readyNodes++
 		readyNodes++
 	}
 	}
 	// Ensure only ready nodes are considered.
 	// Ensure only ready nodes are considered.
@@ -443,9 +443,8 @@ func (m *Mesh) applyTopology() {
 		if !m.peers[k].Ready() {
 		if !m.peers[k].Ready() {
 			continue
 			continue
 		}
 		}
-		// Make a shallow copy of the peer.
-		peer := *m.peers[k]
-		peers[k] = &peer
+		// Make it point the peer without copy.
+		peers[k] = m.peers[k]
 		readyPeers++
 		readyPeers++
 	}
 	}
 	m.nodesGuage.Set(readyNodes)
 	m.nodesGuage.Set(readyNodes)
@@ -469,7 +468,8 @@ func (m *Mesh) applyTopology() {
 		return
 		return
 	}
 	}
 	oldConf := wireguard.Parse(oldConfRaw)
 	oldConf := wireguard.Parse(oldConfRaw)
-	updateNATEndpoints(nodes, peers, oldConf)
+	natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
+	nodes[m.hostname].DiscoveredEndpoints = natEndpoints
 	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
 	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
 	if err != nil {
 	if err != nil {
 		level.Error(m.logger).Log("error", err)
 		level.Error(m.logger).Log("error", err)
@@ -676,26 +676,15 @@ func nodesAreEqual(a, b *Node) bool {
 	if a == b {
 	if a == b {
 		return true
 		return true
 	}
 	}
-	if !(a.Endpoint != nil) == (b.Endpoint != nil) {
+	// Check the DNS name first since this package
+	// is doing the DNS resolution.
+	if !a.Endpoint.Equal(b.Endpoint, true) {
 		return false
 		return false
 	}
 	}
-	if a.Endpoint != nil {
-		if a.Endpoint.Port != b.Endpoint.Port {
-			return false
-		}
-		// Check the DNS name first since this package
-		// is doing the DNS resolution.
-		if a.Endpoint.DNS != b.Endpoint.DNS {
-			return false
-		}
-		if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
-			return false
-		}
-	}
 	// Ignore LastSeen when comparing equality we want to check if the nodes are
 	// Ignore LastSeen when comparing equality we want to check if the nodes are
 	// equivalent. However, we do want to check if LastSeen has transitioned
 	// equivalent. However, we do want to check if LastSeen has transitioned
 	// between valid and invalid.
 	// between valid and invalid.
-	return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive
+	return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints)
 }
 }
 
 
 func peersAreEqual(a, b *Peer) bool {
 func peersAreEqual(a, b *Peer) bool {
@@ -705,22 +694,11 @@ func peersAreEqual(a, b *Peer) bool {
 	if a == b {
 	if a == b {
 		return true
 		return true
 	}
 	}
-	if !(a.Endpoint != nil) == (b.Endpoint != nil) {
+	// Check the DNS name first since this package
+	// is doing the DNS resolution.
+	if !a.Endpoint.Equal(b.Endpoint, true) {
 		return false
 		return false
 	}
 	}
-	if a.Endpoint != nil {
-		if a.Endpoint.Port != b.Endpoint.Port {
-			return false
-		}
-		// Check the DNS name first since this package
-		// is doing the DNS resolution.
-		if a.Endpoint.DNS != b.Endpoint.DNS {
-			return false
-		}
-		if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
-			return false
-		}
-	}
 	if len(a.AllowedIPs) != len(b.AllowedIPs) {
 	if len(a.AllowedIPs) != len(b.AllowedIPs) {
 		return false
 		return false
 	}
 	}
@@ -764,6 +742,24 @@ func subnetsEqual(a, b *net.IPNet) bool {
 	return true
 	return true
 }
 }
 
 
+func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
+	if a == nil && b == nil {
+		return true
+	}
+	if (a != nil) != (b != nil) {
+		return false
+	}
+	if len(a) != len(b) {
+		return false
+	}
+	for k := range a {
+		if !a[k].Equal(b[k], false) {
+			return false
+		}
+	}
+	return true
+}
+
 func linkByIndex(index int) (netlink.Link, error) {
 func linkByIndex(index int) (netlink.Link, error) {
 	link, err := netlink.LinkByIndex(index)
 	link, err := netlink.LinkByIndex(index)
 	if err != nil {
 	if err != nil {
@@ -772,21 +768,30 @@ func linkByIndex(index int) (netlink.Link, error) {
 	return link, nil
 	return link, nil
 }
 }
 
 
-// updateNATEndpoints ensures that nodes and peers behind NAT update
-// their endpoints from the WireGuard configuration so they can roam.
-func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
+// discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
+func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf, logger log.Logger) map[string]*wireguard.Endpoint {
+	natEndpoints := make(map[string]*wireguard.Endpoint)
 	keys := make(map[string]*wireguard.Peer)
 	keys := make(map[string]*wireguard.Peer)
 	for i := range conf.Peers {
 	for i := range conf.Peers {
 		keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
 		keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
 	}
 	}
 	for _, n := range nodes {
 	for _, n := range nodes {
 		if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
 		if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
-			n.Endpoint = peer.Endpoint
+			level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
+			// Should check location leader but only available in topology ... or have topology handle that list
+			// Better check wg latest-handshake
+			if !n.Endpoint.Equal(peer.Endpoint, false) {
+				natEndpoints[string(n.Key)] = peer.Endpoint
+			}
 		}
 		}
 	}
 	}
 	for _, p := range peers {
 	for _, p := range peers {
 		if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
 		if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
-			p.Endpoint = peer.Endpoint
+			if !p.Endpoint.Equal(peer.Endpoint, false) {
+				natEndpoints[string(p.PublicKey)] = peer.Endpoint
+			}
 		}
 		}
 	}
 	}
+	level.Debug(logger).Log("msg", "Discovered WireGuard NAT Endpoints", "DiscoveredEndpoints", natEndpoints)
+	return natEndpoints
 }
 }

+ 44 - 14
pkg/mesh/topology.go

@@ -55,12 +55,15 @@ type Topology struct {
 	// the IP is the 0th address in the subnet, i.e. the CIDR
 	// the IP is the 0th address in the subnet, i.e. the CIDR
 	// is equal to the Kilo subnet.
 	// is equal to the Kilo subnet.
 	wireGuardCIDR *net.IPNet
 	wireGuardCIDR *net.IPNet
+	// discoveredEndpoints is the updated map of valid discovered Endpoints
+	discoveredEndpoints map[string]*wireguard.Endpoint
 }
 }
 
 
 type segment struct {
 type segment struct {
-	allowedIPs []*net.IPNet
-	endpoint   *wireguard.Endpoint
-	key        []byte
+	allowedIPs          []*net.IPNet
+	endpoint            *wireguard.Endpoint
+	key                 []byte
+	persistentKeepalive int
 	// Location is the logical location of this segment.
 	// Location is the logical location of this segment.
 	location string
 	location string
 
 
@@ -106,7 +109,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		localLocation = nodeLocationPrefix + hostname
 		localLocation = nodeLocationPrefix + hostname
 	}
 	}
 
 
-	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet}
+	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint)}
 	for location := range topoMap {
 	for location := range topoMap {
 		// Sort the location so the result is stable.
 		// Sort the location so the result is stable.
 		sort.Slice(topoMap[location], func(i, j int) bool {
 		sort.Slice(topoMap[location], func(i, j int) bool {
@@ -134,14 +137,15 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 			hostnames = append(hostnames, node.Name)
 			hostnames = append(hostnames, node.Name)
 		}
 		}
 		t.segments = append(t.segments, &segment{
 		t.segments = append(t.segments, &segment{
-			allowedIPs: allowedIPs,
-			endpoint:   topoMap[location][leader].Endpoint,
-			key:        topoMap[location][leader].Key,
-			location:   location,
-			cidrs:      cidrs,
-			hostnames:  hostnames,
-			leader:     leader,
-			privateIPs: privateIPs,
+			allowedIPs:          allowedIPs,
+			endpoint:            topoMap[location][leader].Endpoint,
+			key:                 topoMap[location][leader].Key,
+			persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
+			location:            location,
+			cidrs:               cidrs,
+			hostnames:           hostnames,
+			leader:              leader,
+			privateIPs:          privateIPs,
 		})
 		})
 	}
 	}
 	// Sort the Topology segments so the result is stable.
 	// Sort the Topology segments so the result is stable.
@@ -159,6 +163,10 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 	// We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP,
 	// We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP,
 	// the WireGuard configuration could flap, causing the interface to churn.
 	// the WireGuard configuration could flap, causing the interface to churn.
 	t.peers = deduplicatePeerIPs(t.peers)
 	t.peers = deduplicatePeerIPs(t.peers)
+	// Copy the host node DiscoveredEndpoints in the topology as a starting point.
+	for key := range nodes[hostname].DiscoveredEndpoints {
+		t.discoveredEndpoints[key] = nodes[hostname].DiscoveredEndpoints[key]
+	}
 	// Allocate IPs to the segment leaders in a stable, coordination-free manner.
 	// Allocate IPs to the segment leaders in a stable, coordination-free manner.
 	a := newAllocator(*subnet)
 	a := newAllocator(*subnet)
 	for _, segment := range t.segments {
 	for _, segment := range t.segments {
@@ -171,11 +179,33 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		if t.leader && segment.location == t.location {
 		if t.leader && segment.location == t.location {
 			t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
 			t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
 		}
 		}
+
+		// Now that the topology is ordered, update the discoveredEndpoints map
+		// add new ones by going through the ordered topology: segments, nodes
+		for _, node := range topoMap[segment.location] {
+			for key := range node.DiscoveredEndpoints {
+				if _, ok := t.discoveredEndpoints[key]; !ok {
+					t.discoveredEndpoints[key] = node.DiscoveredEndpoints[key]
+				}
+			}
+		}
 	}
 	}
 
 
 	return &t, nil
 	return &t, nil
 }
 }
 
 
+func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint {
+	// Do not update non-nat peers
+	if persistentKeepalive == 0 {
+		return endpoint
+	}
+	e, ok := t.discoveredEndpoints[string(key)]
+	if ok {
+		return e
+	}
+	return endpoint
+}
+
 // Conf generates a WireGuard configuration file for a given Topology.
 // Conf generates a WireGuard configuration file for a given Topology.
 func (t *Topology) Conf() *wireguard.Conf {
 func (t *Topology) Conf() *wireguard.Conf {
 	c := &wireguard.Conf{
 	c := &wireguard.Conf{
@@ -190,7 +220,7 @@ func (t *Topology) Conf() *wireguard.Conf {
 		}
 		}
 		peer := &wireguard.Peer{
 		peer := &wireguard.Peer{
 			AllowedIPs:          s.allowedIPs,
 			AllowedIPs:          s.allowedIPs,
-			Endpoint:            s.endpoint,
+			Endpoint:            t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive),
 			PersistentKeepalive: t.persistentKeepalive,
 			PersistentKeepalive: t.persistentKeepalive,
 			PublicKey:           s.key,
 			PublicKey:           s.key,
 		}
 		}
@@ -199,7 +229,7 @@ func (t *Topology) Conf() *wireguard.Conf {
 	for _, p := range t.peers {
 	for _, p := range t.peers {
 		peer := &wireguard.Peer{
 		peer := &wireguard.Peer{
 			AllowedIPs:          p.AllowedIPs,
 			AllowedIPs:          p.AllowedIPs,
-			Endpoint:            p.Endpoint,
+			Endpoint:            t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepalive),
 			PersistentKeepalive: t.persistentKeepalive,
 			PersistentKeepalive: t.persistentKeepalive,
 			PresharedKey:        p.PresharedKey,
 			PresharedKey:        p.PresharedKey,
 			PublicKey:           p.PublicKey,
 			PublicKey:           p.PublicKey,

+ 225 - 200
pkg/mesh/topology_test.go

@@ -126,34 +126,37 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    logicalLocationPrefix + nodes["a"].Location,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["a"].Location,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    logicalLocationPrefix + nodes["b"].Location,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
-						hostnames:   []string{"b", "c"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["b"].Location,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
+						hostnames:           []string{"b", "c"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w3,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -172,34 +175,37 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    logicalLocationPrefix + nodes["a"].Location,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["a"].Location,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    logicalLocationPrefix + nodes["b"].Location,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
-						hostnames:   []string{"b", "c"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["b"].Location,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
+						hostnames:           []string{"b", "c"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w3,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -218,34 +224,37 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: DefaultKiloSubnet,
 				wireGuardCIDR: DefaultKiloSubnet,
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    logicalLocationPrefix + nodes["a"].Location,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["a"].Location,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    logicalLocationPrefix + nodes["b"].Location,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
-						hostnames:   []string{"b", "c"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            logicalLocationPrefix + nodes["b"].Location,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet, nodes["c"].Subnet},
+						hostnames:           []string{"b", "c"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w3,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -264,44 +273,48 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    nodeLocationPrefix + nodes["a"].Name,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["a"].Name,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    nodeLocationPrefix + nodes["b"].Name,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet},
-						hostnames:   []string{"b"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["b"].Name,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet},
+						hostnames:           []string{"b"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["c"].Endpoint,
-						key:         nodes["c"].Key,
-						location:    nodeLocationPrefix + nodes["c"].Name,
-						cidrs:       []*net.IPNet{nodes["c"].Subnet},
-						hostnames:   []string{"c"},
-						privateIPs:  []net.IP{nodes["c"].InternalIP.IP},
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["c"].Endpoint,
+						key:                 nodes["c"].Key,
+						persistentKeepalive: nodes["c"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["c"].Name,
+						cidrs:               []*net.IPNet{nodes["c"].Subnet},
+						hostnames:           []string{"c"},
+						privateIPs:          []net.IP{nodes["c"].InternalIP.IP},
+						wireGuardIP:         w3,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w4,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w4,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -320,44 +333,48 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    nodeLocationPrefix + nodes["a"].Name,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["a"].Name,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    nodeLocationPrefix + nodes["b"].Name,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet},
-						hostnames:   []string{"b"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["b"].Name,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet},
+						hostnames:           []string{"b"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["c"].Endpoint,
-						key:         nodes["c"].Key,
-						location:    nodeLocationPrefix + nodes["c"].Name,
-						cidrs:       []*net.IPNet{nodes["c"].Subnet},
-						hostnames:   []string{"c"},
-						privateIPs:  []net.IP{nodes["c"].InternalIP.IP},
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["c"].Endpoint,
+						key:                 nodes["c"].Key,
+						persistentKeepalive: nodes["c"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["c"].Name,
+						cidrs:               []*net.IPNet{nodes["c"].Subnet},
+						hostnames:           []string{"c"},
+						privateIPs:          []net.IP{nodes["c"].InternalIP.IP},
+						wireGuardIP:         w3,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w4,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w4,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -376,44 +393,48 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    nodeLocationPrefix + nodes["a"].Name,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["a"].Name,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    nodeLocationPrefix + nodes["b"].Name,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet},
-						hostnames:   []string{"b"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["b"].Name,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet},
+						hostnames:           []string{"b"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["c"].Endpoint,
-						key:         nodes["c"].Key,
-						location:    nodeLocationPrefix + nodes["c"].Name,
-						cidrs:       []*net.IPNet{nodes["c"].Subnet},
-						hostnames:   []string{"c"},
-						privateIPs:  []net.IP{nodes["c"].InternalIP.IP},
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["c"].Endpoint,
+						key:                 nodes["c"].Key,
+						persistentKeepalive: nodes["c"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["c"].Name,
+						cidrs:               []*net.IPNet{nodes["c"].Subnet},
+						hostnames:           []string{"c"},
+						privateIPs:          []net.IP{nodes["c"].InternalIP.IP},
+						wireGuardIP:         w3,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w4,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w4,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},
@@ -432,44 +453,48 @@ func TestNewTopology(t *testing.T) {
 				wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)},
 				wireGuardCIDR: &net.IPNet{IP: w4, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
 				segments: []*segment{
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["a"].Endpoint,
-						key:         nodes["a"].Key,
-						location:    nodeLocationPrefix + nodes["a"].Name,
-						cidrs:       []*net.IPNet{nodes["a"].Subnet},
-						hostnames:   []string{"a"},
-						privateIPs:  []net.IP{nodes["a"].InternalIP.IP},
-						wireGuardIP: w1,
+						allowedIPs:          []*net.IPNet{nodes["a"].Subnet, nodes["a"].InternalIP, {IP: w1, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["a"].Endpoint,
+						key:                 nodes["a"].Key,
+						persistentKeepalive: nodes["a"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["a"].Name,
+						cidrs:               []*net.IPNet{nodes["a"].Subnet},
+						hostnames:           []string{"a"},
+						privateIPs:          []net.IP{nodes["a"].InternalIP.IP},
+						wireGuardIP:         w1,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["b"].Endpoint,
-						key:         nodes["b"].Key,
-						location:    nodeLocationPrefix + nodes["b"].Name,
-						cidrs:       []*net.IPNet{nodes["b"].Subnet},
-						hostnames:   []string{"b"},
-						privateIPs:  []net.IP{nodes["b"].InternalIP.IP},
-						wireGuardIP: w2,
+						allowedIPs:          []*net.IPNet{nodes["b"].Subnet, nodes["b"].InternalIP, {IP: w2, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["b"].Endpoint,
+						key:                 nodes["b"].Key,
+						persistentKeepalive: nodes["b"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["b"].Name,
+						cidrs:               []*net.IPNet{nodes["b"].Subnet},
+						hostnames:           []string{"b"},
+						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
+						wireGuardIP:         w2,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["c"].Endpoint,
-						key:         nodes["c"].Key,
-						location:    nodeLocationPrefix + nodes["c"].Name,
-						cidrs:       []*net.IPNet{nodes["c"].Subnet},
-						hostnames:   []string{"c"},
-						privateIPs:  []net.IP{nodes["c"].InternalIP.IP},
-						wireGuardIP: w3,
+						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["c"].Endpoint,
+						key:                 nodes["c"].Key,
+						persistentKeepalive: nodes["c"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["c"].Name,
+						cidrs:               []*net.IPNet{nodes["c"].Subnet},
+						hostnames:           []string{"c"},
+						privateIPs:          []net.IP{nodes["c"].InternalIP.IP},
+						wireGuardIP:         w3,
 					},
 					},
 					{
 					{
-						allowedIPs:  []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
-						endpoint:    nodes["d"].Endpoint,
-						key:         nodes["d"].Key,
-						location:    nodeLocationPrefix + nodes["d"].Name,
-						cidrs:       []*net.IPNet{nodes["d"].Subnet},
-						hostnames:   []string{"d"},
-						privateIPs:  nil,
-						wireGuardIP: w4,
+						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w4, Mask: net.CIDRMask(32, 32)}},
+						endpoint:            nodes["d"].Endpoint,
+						key:                 nodes["d"].Key,
+						persistentKeepalive: nodes["d"].PersistentKeepalive,
+						location:            nodeLocationPrefix + nodes["d"].Name,
+						cidrs:               []*net.IPNet{nodes["d"].Subnet},
+						hostnames:           []string{"d"},
+						privateIPs:          nil,
+						wireGuardIP:         w4,
 					},
 					},
 				},
 				},
 				peers: []*Peer{peers["a"], peers["b"]},
 				peers: []*Peer{peers["a"], peers["b"]},

+ 33 - 14
pkg/wireguard/conf.go

@@ -95,6 +95,38 @@ func (e *Endpoint) String() string {
 	return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
 	return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
 }
 }
 
 
+// Equal compares two endpoints.
+func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
+	if (e == nil) != (b == nil) {
+		return false
+	}
+	if e != nil {
+		if e.Port != b.Port {
+			return false
+		}
+		if DNSFirst {
+			// Check the DNS name first if it was resolved.
+			if e.DNS != b.DNS {
+				return false
+			}
+			if e.DNS == "" && !e.IP.Equal(b.IP) {
+				return false
+			}
+		} else {
+			// IPs take priority, so check them first.
+			if !e.IP.Equal(b.IP) {
+				return false
+			}
+			// Only check the DNS name if the IP is empty.
+			if e.IP == nil && e.DNS != b.DNS {
+				return false
+			}
+		}
+	}
+
+	return true
+}
+
 // DNSOrIP represents either a DNS name or an IP address.
 // DNSOrIP represents either a DNS name or an IP address.
 // IPs, as they are more specific, are preferred.
 // IPs, as they are more specific, are preferred.
 type DNSOrIP struct {
 type DNSOrIP struct {
@@ -309,22 +341,9 @@ func (c *Conf) Equal(b *Conf) bool {
 				return false
 				return false
 			}
 			}
 		}
 		}
-		if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
+		if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
 			return false
 			return false
 		}
 		}
-		if c.Peers[i].Endpoint != nil {
-			if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
-				return false
-			}
-			// IPs take priority, so check them first.
-			if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
-				return false
-			}
-			// Only check the DNS name if the IP is empty.
-			if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
-				return false
-			}
-		}
 		if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
 		if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
 			return false
 			return false
 		}
 		}

+ 105 - 0
pkg/wireguard/conf_test.go

@@ -15,6 +15,7 @@
 package wireguard
 package wireguard
 
 
 import (
 import (
+	"net"
 	"testing"
 	"testing"
 )
 )
 
 
@@ -203,3 +204,107 @@ func TestCompareConf(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestCompareEndpoint(t *testing.T) {
+	for _, tc := range []struct {
+		name     string
+		a        *Endpoint
+		b        *Endpoint
+		dnsFirst bool
+		out      bool
+	}{
+		{
+			name: "both nil",
+			a:    nil,
+			b:    nil,
+			out:  true,
+		},
+		{
+			name: "a nil",
+			a:    nil,
+			b:    &Endpoint{},
+			out:  false,
+		},
+		{
+			name: "b nil",
+			a:    &Endpoint{},
+			b:    nil,
+			out:  false,
+		},
+		{
+			name: "zero",
+			a:    &Endpoint{},
+			b:    &Endpoint{},
+			out:  true,
+		},
+		{
+			name: "diff port",
+			a:    &Endpoint{Port: 1234},
+			b:    &Endpoint{Port: 5678},
+			out:  false,
+		},
+		{
+			name: "same IP",
+			a:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
+			b:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
+			out:  true,
+		},
+		{
+			name: "diff IP",
+			a:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1")}},
+			b:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2")}},
+			out:  false,
+		},
+		{
+			name: "same IP ignore DNS",
+			a:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
+			b:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "b"}},
+			out:  true,
+		},
+		{
+			name: "no IP check DNS",
+			a:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
+			b:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
+			out:  false,
+		},
+		{
+			name: "no IP check DNS (same)",
+			a:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
+			b:    &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
+			out:  true,
+		},
+		{
+			name:     "DNS first, ignore IP",
+			a:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
+			b:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
+			dnsFirst: true,
+			out:      true,
+		},
+		{
+			name:     "DNS first",
+			a:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
+			b:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
+			dnsFirst: true,
+			out:      false,
+		},
+		{
+			name:     "DNS first, no DNS compare IP",
+			a:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
+			b:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
+			dnsFirst: true,
+			out:      false,
+		},
+		{
+			name:     "DNS first, no DNS compare IP (same)",
+			a:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
+			b:        &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
+			dnsFirst: true,
+			out:      true,
+		},
+	} {
+		equal := tc.a.Equal(tc.b, tc.dnsFirst)
+		if equal != tc.out {
+			t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
+		}
+	}
+}