|
|
@@ -529,7 +529,9 @@ func (m *Mesh) applyTopology() {
|
|
|
if !m.nodes[k].Ready() {
|
|
|
continue
|
|
|
}
|
|
|
- nodes[k] = m.nodes[k]
|
|
|
+ // Make a shallow copy of the node.
|
|
|
+ node := *m.nodes[k]
|
|
|
+ nodes[k] = &node
|
|
|
readyNodes++
|
|
|
}
|
|
|
// Ensure only ready nodes are considered.
|
|
|
@@ -539,7 +541,9 @@ func (m *Mesh) applyTopology() {
|
|
|
if !m.peers[k].Ready() {
|
|
|
continue
|
|
|
}
|
|
|
- peers[k] = m.peers[k]
|
|
|
+ // Make a shallow copy of the peer.
|
|
|
+ peer := *m.peers[k]
|
|
|
+ peers[k] = &peer
|
|
|
readyPeers++
|
|
|
}
|
|
|
m.nodesGuage.Set(readyNodes)
|
|
|
@@ -548,6 +552,22 @@ func (m *Mesh) applyTopology() {
|
|
|
if nodes[m.hostname] == nil {
|
|
|
return
|
|
|
}
|
|
|
+ // Find the Kilo interface name.
|
|
|
+ link, err := linkByIndex(m.kiloIface)
|
|
|
+ if err != nil {
|
|
|
+ level.Error(m.logger).Log("error", err)
|
|
|
+ m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // Find the old configuration.
|
|
|
+ oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
|
|
|
+ if err != nil {
|
|
|
+ level.Error(m.logger).Log("error", err)
|
|
|
+ m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
+ return
|
|
|
+ }
|
|
|
+ oldConf := wireguard.Parse(oldConfRaw)
|
|
|
+ updateNATEndpoints(nodes, peers, oldConf)
|
|
|
t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
|
|
|
if err != nil {
|
|
|
level.Error(m.logger).Log("error", err)
|
|
|
@@ -582,7 +602,6 @@ func (m *Mesh) applyTopology() {
|
|
|
}
|
|
|
}
|
|
|
ipRules = append(ipRules, m.enc.Rules(cidrs)...)
|
|
|
-
|
|
|
// If we are handling local routes, ensure the local
|
|
|
// tunnel has an IP address.
|
|
|
if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
|
|
|
@@ -596,28 +615,15 @@ func (m *Mesh) applyTopology() {
|
|
|
m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
return
|
|
|
}
|
|
|
- // Find the Kilo interface name.
|
|
|
- link, err := linkByIndex(m.kiloIface)
|
|
|
- if err != nil {
|
|
|
- level.Error(m.logger).Log("error", err)
|
|
|
- m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
- return
|
|
|
- }
|
|
|
if t.leader {
|
|
|
if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
|
|
|
level.Error(m.logger).Log("error", err)
|
|
|
m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
return
|
|
|
}
|
|
|
- oldConf, err := wireguard.ShowConf(link.Attrs().Name)
|
|
|
- if err != nil {
|
|
|
- level.Error(m.logger).Log("error", err)
|
|
|
- m.errorCounter.WithLabelValues("apply").Inc()
|
|
|
- return
|
|
|
- }
|
|
|
// Setting the WireGuard configuration interrupts existing connections
|
|
|
// so only set the configuration if it has changed.
|
|
|
- equal := conf.EqualWithPeerCheck(wireguard.Parse(oldConf), peersAreEqualIgnoreNAT)
|
|
|
+ equal := conf.Equal(oldConf)
|
|
|
if !equal {
|
|
|
level.Info(m.logger).Log("msg", "WireGuard configurations are different")
|
|
|
if err := wireguard.SetConf(link.Attrs().Name, ConfPath); err != nil {
|
|
|
@@ -814,41 +820,6 @@ func peersAreEqual(a, b *Peer) bool {
|
|
|
return string(a.PublicKey) == string(b.PublicKey) && a.PersistentKeepalive == b.PersistentKeepalive
|
|
|
}
|
|
|
|
|
|
-// Basic nil checks and checking the lengths of the allowed IPs is
|
|
|
-// done by the WireGuard package.
|
|
|
-func peersAreEqualIgnoreNAT(a, b *wireguard.Peer) bool {
|
|
|
- for j := range a.AllowedIPs {
|
|
|
- if a.AllowedIPs[j].String() != b.AllowedIPs[j].String() {
|
|
|
- return false
|
|
|
- }
|
|
|
- }
|
|
|
- if a.PersistentKeepalive != b.PersistentKeepalive || !bytes.Equal(a.PublicKey, b.PublicKey) {
|
|
|
- return false
|
|
|
- }
|
|
|
- // If a persistent keepalive is set, then the peer is behind NAT
|
|
|
- // and we want to ignore changes in endpoints, since it may roam.
|
|
|
- if a.PersistentKeepalive != 0 {
|
|
|
- return true
|
|
|
- }
|
|
|
- if (a.Endpoint == nil) != (b.Endpoint == nil) {
|
|
|
- return false
|
|
|
- }
|
|
|
- if a.Endpoint != nil {
|
|
|
- if a.Endpoint.Port != b.Endpoint.Port {
|
|
|
- return false
|
|
|
- }
|
|
|
- // IPs take priority, so check them first.
|
|
|
- if !a.Endpoint.IP.Equal(b.Endpoint.IP) {
|
|
|
- return false
|
|
|
- }
|
|
|
- // Only check the DNS name if the IP is empty.
|
|
|
- if a.Endpoint.IP == nil && a.Endpoint.DNS != b.Endpoint.DNS {
|
|
|
- return false
|
|
|
- }
|
|
|
- }
|
|
|
- return true
|
|
|
-}
|
|
|
-
|
|
|
func ipNetsEqual(a, b *net.IPNet) bool {
|
|
|
if a == nil && b == nil {
|
|
|
return true
|
|
|
@@ -888,3 +859,22 @@ func linkByIndex(index int) (netlink.Link, error) {
|
|
|
}
|
|
|
return link, nil
|
|
|
}
|
|
|
+
|
|
|
+// updateNATEndpoints ensures that nodes and peers behind NAT update
|
|
|
+// their endpoints from the WireGuard configuration so they can roam.
|
|
|
+func updateNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wireguard.Conf) {
|
|
|
+ keys := make(map[string]*wireguard.Peer)
|
|
|
+ for i := range conf.Peers {
|
|
|
+ keys[string(conf.Peers[i].PublicKey)] = conf.Peers[i]
|
|
|
+ }
|
|
|
+ for _, n := range nodes {
|
|
|
+ if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
|
|
|
+ n.Endpoint = peer.Endpoint
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for _, p := range peers {
|
|
|
+ if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
|
|
|
+ p.Endpoint = peer.Endpoint
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|