| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867 |
- // Copyright 2021 the Kilo authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- //go:build linux
- // +build linux
- package mesh
- import (
- "bytes"
- "context"
- "fmt"
- "net"
- "os"
- "sync"
- "time"
- "github.com/go-kit/kit/log"
- "github.com/go-kit/kit/log/level"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/vishvananda/netlink"
- "golang.zx2c4.com/wireguard/wgctrl"
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/squat/kilo/pkg/encapsulation"
- "github.com/squat/kilo/pkg/iproute"
- "github.com/squat/kilo/pkg/iptables"
- "github.com/squat/kilo/pkg/route"
- "github.com/squat/kilo/pkg/wireguard"
- )
// kiloPath is the directory where Kilo stores its configuration.
const kiloPath = "/var/lib/kilo"

// privateKeyPath is the file inside kiloPath where the WireGuard private key
// is stored.
const privateKeyPath = kiloPath + "/key"
- // Mesh is able to create Kilo network meshes.
- type Mesh struct {
- Backend
- cleanup bool
- cleanUpIface bool
- cni bool
- cniPath string
- enc encapsulation.Encapsulator
- externalIP *net.IPNet
- granularity Granularity
- hostname string
- internalIP *net.IPNet
- ipTables *iptables.Controller
- kiloIface int
- kiloIfaceName string
- local bool
- mtu uint
- autoMTU bool
- port int
- priv wgtypes.Key
- privIface int
- pub wgtypes.Key
- resyncPeriod time.Duration
- iptablesForwardRule bool
- serviceCIDRs []*net.IPNet
- subnet *net.IPNet
- table *route.Table
- wireGuardIP *net.IPNet
- // nodes and peers are mutable fields in the struct
- // and need to be guarded.
- nodes map[string]*Node
- peers map[string]*Peer
- mu sync.Mutex
- errorCounter *prometheus.CounterVec
- leaderGuage prometheus.Gauge
- nodesGuage prometheus.Gauge
- peersGuage prometheus.Gauge
- reconcileCounter prometheus.Counter
- logger log.Logger
- }
- // New returns a new Mesh instance.
- func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanup bool, cleanUpIface bool, createIface bool, mtu uint, autoMTU bool, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, allowedInternalCIDRs []*net.IPNet, serviceCIDRs []*net.IPNet, logger log.Logger, registerer prometheus.Registerer) (*Mesh, error) {
- if err := os.MkdirAll(kiloPath, 0700); err != nil {
- return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
- }
- privateB, err := os.ReadFile(privateKeyPath)
- if err != nil && !os.IsNotExist(err) {
- return nil, fmt.Errorf("failed to read private key file: %v", err)
- }
- privateB = bytes.Trim(privateB, "\n")
- private, err := wgtypes.ParseKey(string(privateB))
- if err != nil {
- _ = level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
- if private, err = wgtypes.GeneratePrivateKey(); err != nil {
- return nil, err
- }
- if err := os.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil {
- return nil, fmt.Errorf("failed to write private key to disk: %v", err)
- }
- }
- public := private.PublicKey()
- if err != nil {
- return nil, err
- }
- cniIndex, err := cniDeviceIndex()
- if err != nil {
- return nil, fmt.Errorf("failed to query netlink for CNI device: %v", err)
- }
- var kiloIface int
- if createIface {
- link, err := netlink.LinkByName(iface)
- if err != nil {
- kiloIface, _, err = wireguard.New(iface, mtu)
- if err != nil {
- return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
- }
- } else {
- kiloIface = link.Attrs().Index
- }
- } else {
- link, err := netlink.LinkByName(iface)
- if err != nil {
- return nil, fmt.Errorf("failed to get interface index: %v", err)
- }
- kiloIface = link.Attrs().Index
- }
- privateIP, publicIP, err := getIP(hostname, allowedInternalCIDRs, kiloIface, enc.Index(), cniIndex)
- if err != nil {
- return nil, fmt.Errorf("failed to find public IP: %v", err)
- }
- var privIface int
- if privateIP != nil {
- ifaces, err := interfacesForIP(privateIP)
- if err != nil {
- return nil, fmt.Errorf("failed to find interface for private IP: %v", err)
- }
- privIface = ifaces[0].Index
- if enc.Strategy() != encapsulation.Never {
- if err := enc.Init(privIface); err != nil {
- return nil, fmt.Errorf("failed to initialize encapsulator: %v", err)
- }
- }
- _ = level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
- } else {
- enc = encapsulation.Noop(enc.Strategy())
- _ = level.Debug(logger).Log("msg", "running without a private IP address")
- }
- if autoMTU {
- mtu = detectMTU(logger)
- }
- var externalIP *net.IPNet
- if prioritisePrivateAddr && privateIP != nil {
- externalIP = privateIP
- } else {
- externalIP = publicIP
- }
- _ = level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
- ipTables, err := iptables.New(iptables.WithRegisterer(registerer), iptables.WithLogger(log.With(logger, "component", "iptables")), iptables.WithResyncPeriod(resyncPeriod))
- if err != nil {
- return nil, fmt.Errorf("failed to IP tables controller: %v", err)
- }
- mesh := Mesh{
- Backend: backend,
- cleanup: cleanup,
- cleanUpIface: cleanUpIface,
- cni: cni,
- cniPath: cniPath,
- enc: enc,
- externalIP: externalIP,
- granularity: granularity,
- hostname: hostname,
- internalIP: privateIP,
- ipTables: ipTables,
- kiloIface: kiloIface,
- kiloIfaceName: iface,
- nodes: make(map[string]*Node),
- peers: make(map[string]*Peer),
- port: port,
- priv: private,
- privIface: privIface,
- pub: public,
- resyncPeriod: resyncPeriod,
- iptablesForwardRule: iptablesForwardRule,
- local: local,
- mtu: mtu,
- autoMTU: autoMTU,
- serviceCIDRs: serviceCIDRs,
- subnet: subnet,
- table: route.NewTable(),
- errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
- Name: "kilo_errors_total",
- Help: "Number of errors that occurred while administering the mesh.",
- }, []string{"event"}),
- leaderGuage: prometheus.NewGauge(prometheus.GaugeOpts{
- Name: "kilo_leader",
- Help: "Leadership status of the node.",
- }),
- nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
- Name: "kilo_nodes",
- Help: "Number of nodes in the mesh.",
- }),
- peersGuage: prometheus.NewGauge(prometheus.GaugeOpts{
- Name: "kilo_peers",
- Help: "Number of peers in the mesh.",
- }),
- reconcileCounter: prometheus.NewCounter(prometheus.CounterOpts{
- Name: "kilo_reconciles_total",
- Help: "Number of reconciliation attempts.",
- }),
- logger: logger,
- }
- registerer.MustRegister(
- mesh.errorCounter,
- mesh.leaderGuage,
- mesh.nodesGuage,
- mesh.peersGuage,
- mesh.reconcileCounter,
- )
- return &mesh, nil
- }
- // Run starts the mesh.
- func (m *Mesh) Run(ctx context.Context) error {
- if err := m.Nodes().Init(ctx); err != nil {
- return fmt.Errorf("failed to initialize node backend: %v", err)
- }
- // Try to set the CNI config quickly.
- if m.cni {
- if n, err := m.Nodes().Get(m.hostname); err == nil {
- m.nodes[m.hostname] = n
- m.updateCNIConfig()
- } else {
- _ = level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
- }
- }
- if err := m.Peers().Init(ctx); err != nil {
- return fmt.Errorf("failed to initialize peer backend: %v", err)
- }
- ipTablesErrors, err := m.ipTables.Run(ctx.Done())
- if err != nil {
- return fmt.Errorf("failed to watch for IP tables updates: %v", err)
- }
- routeErrors, err := m.table.Run(ctx.Done())
- if err != nil {
- return fmt.Errorf("failed to watch for route table updates: %v", err)
- }
- go func() {
- for {
- var err error
- select {
- case err = <-ipTablesErrors:
- case err = <-routeErrors:
- case <-ctx.Done():
- return
- }
- if err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("run").Inc()
- }
- }
- }()
- if m.cleanup {
- defer m.cleanUp()
- }
- resync := time.NewTimer(m.resyncPeriod)
- checkIn := time.NewTimer(checkInPeriod)
- nw := m.Nodes().Watch()
- pw := m.Peers().Watch()
- var ne *NodeEvent
- var pe *PeerEvent
- for {
- select {
- case ne = <-nw:
- m.syncNodes(ctx, ne)
- case pe = <-pw:
- m.syncPeers(pe)
- case <-checkIn.C:
- m.checkIn(ctx)
- checkIn.Reset(checkInPeriod)
- case <-resync.C:
- if m.cni {
- m.updateCNIConfig()
- }
- m.applyTopology()
- resync.Reset(m.resyncPeriod)
- case <-ctx.Done():
- return nil
- }
- }
- }
- func (m *Mesh) syncNodes(ctx context.Context, e *NodeEvent) {
- logger := log.With(m.logger, "event", e.Type)
- _ = level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
- if isSelf(m.hostname, e.Node) {
- _ = level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
- m.handleLocal(ctx, e.Node)
- return
- }
- var diff bool
- m.mu.Lock()
- if !e.Node.Ready() {
- // Trace non ready nodes with their presence in the mesh.
- _, ok := m.nodes[e.Node.Name]
- _ = level.Debug(logger).Log("msg", "received non ready node", "node", e.Node, "in-mesh", ok)
- }
- switch e.Type {
- case AddEvent:
- fallthrough
- case UpdateEvent:
- if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
- diff = true
- }
- // Even if the nodes are the same,
- // overwrite the old node to update the timestamp.
- m.nodes[e.Node.Name] = e.Node
- case DeleteEvent:
- delete(m.nodes, e.Node.Name)
- diff = true
- }
- m.mu.Unlock()
- if diff {
- _ = level.Info(logger).Log("node", e.Node)
- m.applyTopology()
- }
- }
- func (m *Mesh) syncPeers(e *PeerEvent) {
- logger := log.With(m.logger, "event", e.Type)
- _ = level.Debug(logger).Log("msg", "syncing peers", "event", e.Type)
- var diff bool
- m.mu.Lock()
- // Peers are indexed by public key.
- key := e.Peer.PublicKey.String()
- if !e.Peer.Ready() {
- // Trace non ready peer with their presence in the mesh.
- _, ok := m.peers[key]
- _ = level.Debug(logger).Log("msg", "received non ready peer", "peer", e.Peer, "in-mesh", ok)
- }
- switch e.Type {
- case AddEvent:
- fallthrough
- case UpdateEvent:
- if e.Old != nil && key != e.Old.PublicKey.String() {
- delete(m.peers, e.Old.PublicKey.String())
- diff = true
- }
- if !peersAreEqual(m.peers[key], e.Peer) {
- m.peers[key] = e.Peer
- diff = true
- }
- case DeleteEvent:
- delete(m.peers, key)
- diff = true
- }
- m.mu.Unlock()
- if diff {
- _ = level.Info(logger).Log("peer", e.Peer)
- m.applyTopology()
- }
- }
- // checkIn will try to update the local node's LastSeen timestamp
- // in the backend.
- func (m *Mesh) checkIn(ctx context.Context) {
- m.mu.Lock()
- defer m.mu.Unlock()
- n := m.nodes[m.hostname]
- if n == nil {
- _ = level.Debug(m.logger).Log("msg", "no local node found in backend")
- return
- }
- oldTime := n.LastSeen
- n.LastSeen = time.Now().Unix()
- if err := m.Nodes().Set(ctx, m.hostname, n); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
- m.errorCounter.WithLabelValues("checkin").Inc()
- // Revert time.
- n.LastSeen = oldTime
- return
- }
- _ = level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
- }
- func (m *Mesh) handleLocal(ctx context.Context, n *Node) {
- // Allow the IPs to be overridden.
- if !n.Endpoint.Ready() {
- e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
- _ = level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
- n.Endpoint = e
- }
- if n.InternalIP == nil && !n.NoInternalIP {
- n.InternalIP = m.internalIP
- }
- // Compare the given node to the calculated local node.
- // Take leader, location, and subnet from the argument, as these
- // are not determined by kilo.
- local := &Node{
- Endpoint: n.Endpoint,
- Key: m.pub,
- NoInternalIP: n.NoInternalIP,
- InternalIP: n.InternalIP,
- LastSeen: time.Now().Unix(),
- Leader: n.Leader,
- Location: n.Location,
- Name: m.hostname,
- PersistentKeepalive: n.PersistentKeepalive,
- Subnet: n.Subnet,
- WireGuardIP: m.wireGuardIP,
- DiscoveredEndpoints: n.DiscoveredEndpoints,
- AllowedLocationIPs: n.AllowedLocationIPs,
- Granularity: m.granularity,
- }
- if !nodesAreEqual(n, local) {
- _ = level.Debug(m.logger).Log("msg", "local node differs from backend")
- if err := m.Nodes().Set(ctx, m.hostname, local); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
- m.errorCounter.WithLabelValues("local").Inc()
- return
- }
- _ = level.Debug(m.logger).Log("msg", "successfully reconciled local node against backend")
- }
- m.mu.Lock()
- n = m.nodes[m.hostname]
- if n == nil {
- n = &Node{}
- }
- m.mu.Unlock()
- if !nodesAreEqual(n, local) {
- m.mu.Lock()
- m.nodes[local.Name] = local
- m.mu.Unlock()
- m.applyTopology()
- }
- }
- func (m *Mesh) applyTopology() {
- m.reconcileCounter.Inc()
- m.mu.Lock()
- defer m.mu.Unlock()
- // If we can't resolve an endpoint, then fail and retry later.
- if err := m.resolveEndpoints(); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- // Ensure only ready nodes are considered.
- nodes := make(map[string]*Node)
- var readyNodes float64
- for k := range m.nodes {
- m.nodes[k].Granularity = m.granularity
- if !m.nodes[k].Ready() {
- continue
- }
- // Make it point to the node without copy.
- nodes[k] = m.nodes[k]
- readyNodes++
- }
- // Ensure only ready nodes are considered.
- peers := make(map[string]*Peer)
- var readyPeers float64
- for k := range m.peers {
- if !m.peers[k].Ready() {
- continue
- }
- // Make it point the peer without copy.
- peers[k] = m.peers[k]
- readyPeers++
- }
- m.nodesGuage.Set(readyNodes)
- m.peersGuage.Set(readyPeers)
- // We cannot do anything with the topology until the local node is available.
- if nodes[m.hostname] == nil {
- return
- }
- // Re-detect MTU if auto mode is enabled.
- if m.autoMTU {
- m.mtu = detectMTU(m.logger)
- }
- // Ensure the WireGuard interface has the correct MTU.
- if err := wireguard.SetMTU(m.kiloIface, m.mtu); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Errorf("failed to set MTU on WireGuard interface: %v", err))
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- // Find the Kilo interface name.
- link, err := linkByIndex(m.kiloIface)
- if err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- wgClient, err := wgctrl.New()
- if err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- defer func() { _ = wgClient.Close() }()
- // wgDevice is the current configuration of the wg interface.
- wgDevice, err := wgClient.Device(m.kiloIfaceName)
- if err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
- nodes[m.hostname].DiscoveredEndpoints = natEndpoints
- t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, m.serviceCIDRs, nodes[m.hostname].PersistentKeepalive, m.logger)
- if err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- // Update the node's WireGuard IP.
- if t.leader {
- m.wireGuardIP = t.wireGuardCIDR
- } else {
- m.wireGuardIP = nil
- }
- ipRules := t.Rules(m.cni, m.iptablesForwardRule)
- // If we are handling local routes, ensure the local
- // tunnel has an IP address and IPIP traffic is allowed.
- if m.enc.Strategy() != encapsulation.Never && m.local {
- var cidrs []*net.IPNet
- for _, s := range t.segments {
- // If the location prefix is not logicalLocation, but nodeLocation,
- // we don't need to set any extra rules for encapsulation anyways
- // because traffic will go over WireGuard.
- if s.location == logicalLocationPrefix+nodes[m.hostname].Location {
- for i := range s.privateIPs {
- cidrs = append(cidrs, oneAddressCIDR(s.privateIPs[i]))
- }
- break
- }
- }
- encIpRules := m.enc.Rules(cidrs)
- ipRules = encIpRules.AppendRuleSet(ipRules)
- // If we are handling local routes, ensure the local
- // tunnel has an IP address.
- if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- }
- if err := m.ipTables.Set(ipRules); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- if t.leader {
- m.leaderGuage.Set(1)
- if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- // Setting the WireGuard configuration interrupts existing connections
- // so only set the configuration if it has changed.
- conf := t.Conf()
- equal, diff := conf.Equal(wgDevice)
- if !equal {
- _ = level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff)
- _ = level.Debug(m.logger).Log("msg", "changing wg config", "config", conf.WGConfig())
- if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- }
- if err := iproute.Set(m.kiloIface, true); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- } else {
- m.leaderGuage.Set(0)
- _ = level.Debug(m.logger).Log("msg", "local node is not the leader")
- if err := iproute.Set(m.kiloIface, false); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- return
- }
- }
- // We need to add routes last since they may depend
- // on the WireGuard interface.
- routes, rules := t.Routes(link.Attrs().Name, m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc)
- if err := m.table.Set(routes, rules); err != nil {
- _ = level.Error(m.logger).Log("error", err)
- m.errorCounter.WithLabelValues("apply").Inc()
- }
- }
- func (m *Mesh) cleanUp() {
- if err := m.ipTables.CleanUp(); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- if err := m.table.CleanUp(); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- if m.cleanUpIface {
- if err := iproute.RemoveInterface(m.kiloIface); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- }
- {
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- if err := m.Nodes().CleanUp(ctx, m.hostname); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- }
- {
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- if err := m.Peers().CleanUp(ctx, m.hostname); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- }
- if err := m.enc.CleanUp(); err != nil {
- _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulator: %v", err))
- m.errorCounter.WithLabelValues("cleanUp").Inc()
- }
- }
- func (m *Mesh) resolveEndpoints() error {
- for k := range m.nodes {
- // Skip unready nodes, since they will not be used
- // in the topology anyways.
- if !m.nodes[k].Ready() {
- continue
- }
- // Resolve the Endpoint
- if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
- return err
- }
- }
- for k := range m.peers {
- // Skip unready peers, since they will not be used
- // in the topology anyways.
- if !m.peers[k].Ready() {
- continue
- }
- // Peers may have nil endpoints.
- if !m.peers[k].Endpoint.Ready() {
- continue
- }
- if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
- return err
- }
- }
- return nil
- }
- func isSelf(hostname string, node *Node) bool {
- return node != nil && node.Name == hostname
- }
- func nodesAreEqual(a, b *Node) bool {
- if (a != nil) != (b != nil) {
- return false
- }
- if a == b {
- return true
- }
- // Check the DNS name first since this package
- // is doing the DNS resolution.
- if !a.Endpoint.Equal(b.Endpoint, true) {
- return false
- }
- // Ignore LastSeen when comparing equality we want to check if the nodes are
- // equivalent. However, we do want to check if LastSeen has transitioned
- // between valid and invalid.
- return a.Key.String() == b.Key.String() &&
- ipNetsEqual(a.WireGuardIP, b.WireGuardIP) &&
- ipNetsEqual(a.InternalIP, b.InternalIP) &&
- a.Leader == b.Leader &&
- a.Location == b.Location &&
- a.Name == b.Name &&
- subnetsEqual(a.Subnet, b.Subnet) &&
- a.Ready() == b.Ready() &&
- a.PersistentKeepalive == b.PersistentKeepalive &&
- discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) &&
- ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) &&
- a.Granularity == b.Granularity
- }
- func peersAreEqual(a, b *Peer) bool {
- if (a != nil) != (b != nil) {
- return false
- }
- if a == b {
- return true
- }
- // Check the DNS name first since this package
- // is doing the DNS resolution.
- if !a.Endpoint.Equal(b.Endpoint, true) {
- return false
- }
- if len(a.AllowedIPs) != len(b.AllowedIPs) {
- return false
- }
- for i := range a.AllowedIPs {
- if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) {
- return false
- }
- }
- return a.PublicKey.String() == b.PublicKey.String() &&
- (a.PresharedKey == nil) == (b.PresharedKey == nil) &&
- (a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) &&
- (a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) &&
- (a.PersistentKeepaliveInterval == nil || *a.PersistentKeepaliveInterval == *b.PersistentKeepaliveInterval)
- }
// ipNetsEqual reports whether a and b have the same mask and the same IP.
// Two nil networks are equal; a nil and a non-nil network are not.
func ipNetsEqual(a, b *net.IPNet) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	return a.Mask.String() == b.Mask.String() && a.IP.Equal(b.IP)
}
- func ipNetSlicesEqual(a, b []net.IPNet) bool {
- if len(a) != len(b) {
- return false
- }
- for i := range a {
- if !ipNetsEqual(&a[i], &b[i]) {
- return false
- }
- }
- return true
- }
// subnetsEqual reports whether a and b denote the same subnet: the masks
// match and each network contains the other's IP. Two nil subnets are equal;
// a nil and a non-nil subnet are not.
func subnetsEqual(a, b *net.IPNet) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	return a.Mask.String() == b.Mask.String() &&
		a.Contains(b.IP) &&
		b.Contains(a.IP)
}
// udpAddrsEqual reports whether a and b have the same zone, port and IP.
// Two nil addresses are equal; a nil and a non-nil address are not.
func udpAddrsEqual(a, b *net.UDPAddr) bool {
	if a == nil || b == nil {
		// Equal only when both are nil.
		return a == b
	}
	return a.Zone == b.Zone && a.Port == b.Port && a.IP.Equal(b.IP)
}
- func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool {
- if a == nil && b == nil {
- return true
- }
- if len(a) != len(b) {
- return false
- }
- for k := range a {
- if !udpAddrsEqual(a[k], b[k]) {
- return false
- }
- }
- return true
- }
- func detectMTU(logger log.Logger) uint {
- iface, err := defaultInterface()
- if err != nil {
- _ = level.Warn(logger).Log("msg", "failed to get default interface for MTU detection, using default MTU", "error", err)
- return wireguard.DefaultMTU
- }
- link, err := netlink.LinkByIndex(iface.Index)
- if err != nil {
- _ = level.Warn(logger).Log("msg", "failed to get default interface link for MTU detection, using default MTU", "error", err)
- return wireguard.DefaultMTU
- }
- baseMTU := link.Attrs().MTU
- _ = level.Info(logger).Log("msg", fmt.Sprintf("detected underlay MTU %d on default interface %s", baseMTU, link.Attrs().Name))
- mtu := uint(baseMTU) - wireguard.WireGuardOverhead
- _ = level.Info(logger).Log("msg", fmt.Sprintf("auto-detected WireGuard MTU: %d (underlay %d - overhead %d)", mtu, baseMTU, wireguard.WireGuardOverhead))
- return mtu
- }
- func linkByIndex(index int) (netlink.Link, error) {
- link, err := netlink.LinkByIndex(index)
- if err != nil {
- return nil, fmt.Errorf("failed to get interface: %v", err)
- }
- return link, nil
- }
- // discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
- // Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr.
- func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr {
- natEndpoints := make(map[string]*net.UDPAddr)
- keys := make(map[string]wgtypes.Peer)
- for i := range conf.Peers {
- keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i]
- }
- for _, n := range nodes {
- if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) {
- _ = level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime)
- // Don't update the endpoint, if there was never any handshake.
- if !peer.LastHandshakeTime.Equal(time.Time{}) {
- natEndpoints[n.Key.String()] = peer.Endpoint
- }
- }
- }
- for _, p := range peers {
- if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil {
- if !peer.LastHandshakeTime.Equal(time.Time{}) {
- natEndpoints[p.PublicKey.String()] = peer.Endpoint
- }
- }
- }
- _ = level.Debug(logger).Log("msg", "Discovered WireGuard NAT Endpoints", "DiscoveredEndpoints", natEndpoints)
- return natEndpoints
- }
|