mesh.go 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. // Copyright 2021 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //go:build linux
  15. // +build linux
  16. package mesh
  17. import (
  18. "bytes"
  19. "context"
  20. "fmt"
  21. "net"
  22. "os"
  23. "sync"
  24. "time"
  25. "github.com/go-kit/kit/log"
  26. "github.com/go-kit/kit/log/level"
  27. "github.com/prometheus/client_golang/prometheus"
  28. "github.com/vishvananda/netlink"
  29. "golang.zx2c4.com/wireguard/wgctrl"
  30. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  31. "github.com/squat/kilo/pkg/encapsulation"
  32. "github.com/squat/kilo/pkg/iproute"
  33. "github.com/squat/kilo/pkg/iptables"
  34. "github.com/squat/kilo/pkg/route"
  35. "github.com/squat/kilo/pkg/wireguard"
  36. )
  37. const (
  38. // kiloPath is the directory where Kilo stores its configuration.
  39. kiloPath = "/var/lib/kilo"
  40. // privateKeyPath is the filepath where the WireGuard private key is stored.
  41. privateKeyPath = kiloPath + "/key"
  42. )
  43. // Mesh is able to create Kilo network meshes.
  44. type Mesh struct {
  45. Backend
  46. cleanup bool
  47. cleanUpIface bool
  48. cni bool
  49. cniPath string
  50. enc encapsulation.Encapsulator
  51. externalIP *net.IPNet
  52. granularity Granularity
  53. hostname string
  54. internalIP *net.IPNet
  55. ipTables *iptables.Controller
  56. kiloIface int
  57. kiloIfaceName string
  58. local bool
  59. mtu uint
  60. autoMTU bool
  61. port int
  62. priv wgtypes.Key
  63. privIface int
  64. pub wgtypes.Key
  65. resyncPeriod time.Duration
  66. iptablesForwardRule bool
  67. serviceCIDRs []*net.IPNet
  68. subnet *net.IPNet
  69. table *route.Table
  70. wireGuardIP *net.IPNet
  71. // nodes and peers are mutable fields in the struct
  72. // and need to be guarded.
  73. nodes map[string]*Node
  74. peers map[string]*Peer
  75. mu sync.Mutex
  76. errorCounter *prometheus.CounterVec
  77. leaderGuage prometheus.Gauge
  78. nodesGuage prometheus.Gauge
  79. peersGuage prometheus.Gauge
  80. reconcileCounter prometheus.Counter
  81. logger log.Logger
  82. }
  83. // New returns a new Mesh instance.
  84. func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanup bool, cleanUpIface bool, createIface bool, mtu uint, autoMTU bool, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, allowedInternalCIDRs []*net.IPNet, serviceCIDRs []*net.IPNet, logger log.Logger, registerer prometheus.Registerer) (*Mesh, error) {
  85. if err := os.MkdirAll(kiloPath, 0700); err != nil {
  86. return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
  87. }
  88. privateB, err := os.ReadFile(privateKeyPath)
  89. if err != nil && !os.IsNotExist(err) {
  90. return nil, fmt.Errorf("failed to read private key file: %v", err)
  91. }
  92. privateB = bytes.Trim(privateB, "\n")
  93. private, err := wgtypes.ParseKey(string(privateB))
  94. if err != nil {
  95. _ = level.Warn(logger).Log("msg", "no private key found on disk; generating one now")
  96. if private, err = wgtypes.GeneratePrivateKey(); err != nil {
  97. return nil, err
  98. }
  99. if err := os.WriteFile(privateKeyPath, []byte(private.String()), 0600); err != nil {
  100. return nil, fmt.Errorf("failed to write private key to disk: %v", err)
  101. }
  102. }
  103. public := private.PublicKey()
  104. if err != nil {
  105. return nil, err
  106. }
  107. cniIndex, err := cniDeviceIndex()
  108. if err != nil {
  109. return nil, fmt.Errorf("failed to query netlink for CNI device: %v", err)
  110. }
  111. var kiloIface int
  112. if createIface {
  113. link, err := netlink.LinkByName(iface)
  114. if err != nil {
  115. kiloIface, _, err = wireguard.New(iface, mtu)
  116. if err != nil {
  117. return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
  118. }
  119. } else {
  120. kiloIface = link.Attrs().Index
  121. }
  122. } else {
  123. link, err := netlink.LinkByName(iface)
  124. if err != nil {
  125. return nil, fmt.Errorf("failed to get interface index: %v", err)
  126. }
  127. kiloIface = link.Attrs().Index
  128. }
  129. privateIP, publicIP, err := getIP(hostname, allowedInternalCIDRs, kiloIface, enc.Index(), cniIndex)
  130. if err != nil {
  131. return nil, fmt.Errorf("failed to find public IP: %v", err)
  132. }
  133. var privIface int
  134. if privateIP != nil {
  135. ifaces, err := interfacesForIP(privateIP)
  136. if err != nil {
  137. return nil, fmt.Errorf("failed to find interface for private IP: %v", err)
  138. }
  139. privIface = ifaces[0].Index
  140. if enc.Strategy() != encapsulation.Never {
  141. if err := enc.Init(privIface); err != nil {
  142. return nil, fmt.Errorf("failed to initialize encapsulator: %v", err)
  143. }
  144. }
  145. _ = level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
  146. } else {
  147. enc = encapsulation.Noop(enc.Strategy())
  148. _ = level.Debug(logger).Log("msg", "running without a private IP address")
  149. }
  150. if autoMTU {
  151. mtu = detectMTU(logger)
  152. }
  153. var externalIP *net.IPNet
  154. if prioritisePrivateAddr && privateIP != nil {
  155. externalIP = privateIP
  156. } else {
  157. externalIP = publicIP
  158. }
  159. _ = level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
  160. ipTables, err := iptables.New(iptables.WithRegisterer(registerer), iptables.WithLogger(log.With(logger, "component", "iptables")), iptables.WithResyncPeriod(resyncPeriod))
  161. if err != nil {
  162. return nil, fmt.Errorf("failed to IP tables controller: %v", err)
  163. }
  164. mesh := Mesh{
  165. Backend: backend,
  166. cleanup: cleanup,
  167. cleanUpIface: cleanUpIface,
  168. cni: cni,
  169. cniPath: cniPath,
  170. enc: enc,
  171. externalIP: externalIP,
  172. granularity: granularity,
  173. hostname: hostname,
  174. internalIP: privateIP,
  175. ipTables: ipTables,
  176. kiloIface: kiloIface,
  177. kiloIfaceName: iface,
  178. nodes: make(map[string]*Node),
  179. peers: make(map[string]*Peer),
  180. port: port,
  181. priv: private,
  182. privIface: privIface,
  183. pub: public,
  184. resyncPeriod: resyncPeriod,
  185. iptablesForwardRule: iptablesForwardRule,
  186. local: local,
  187. mtu: mtu,
  188. autoMTU: autoMTU,
  189. serviceCIDRs: serviceCIDRs,
  190. subnet: subnet,
  191. table: route.NewTable(),
  192. errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
  193. Name: "kilo_errors_total",
  194. Help: "Number of errors that occurred while administering the mesh.",
  195. }, []string{"event"}),
  196. leaderGuage: prometheus.NewGauge(prometheus.GaugeOpts{
  197. Name: "kilo_leader",
  198. Help: "Leadership status of the node.",
  199. }),
  200. nodesGuage: prometheus.NewGauge(prometheus.GaugeOpts{
  201. Name: "kilo_nodes",
  202. Help: "Number of nodes in the mesh.",
  203. }),
  204. peersGuage: prometheus.NewGauge(prometheus.GaugeOpts{
  205. Name: "kilo_peers",
  206. Help: "Number of peers in the mesh.",
  207. }),
  208. reconcileCounter: prometheus.NewCounter(prometheus.CounterOpts{
  209. Name: "kilo_reconciles_total",
  210. Help: "Number of reconciliation attempts.",
  211. }),
  212. logger: logger,
  213. }
  214. registerer.MustRegister(
  215. mesh.errorCounter,
  216. mesh.leaderGuage,
  217. mesh.nodesGuage,
  218. mesh.peersGuage,
  219. mesh.reconcileCounter,
  220. )
  221. return &mesh, nil
  222. }
  223. // Run starts the mesh.
  224. func (m *Mesh) Run(ctx context.Context) error {
  225. if err := m.Nodes().Init(ctx); err != nil {
  226. return fmt.Errorf("failed to initialize node backend: %v", err)
  227. }
  228. // Try to set the CNI config quickly.
  229. if m.cni {
  230. if n, err := m.Nodes().Get(m.hostname); err == nil {
  231. m.nodes[m.hostname] = n
  232. m.updateCNIConfig()
  233. } else {
  234. _ = level.Warn(m.logger).Log("error", fmt.Errorf("failed to get node %q: %v", m.hostname, err))
  235. }
  236. }
  237. if err := m.Peers().Init(ctx); err != nil {
  238. return fmt.Errorf("failed to initialize peer backend: %v", err)
  239. }
  240. ipTablesErrors, err := m.ipTables.Run(ctx.Done())
  241. if err != nil {
  242. return fmt.Errorf("failed to watch for IP tables updates: %v", err)
  243. }
  244. routeErrors, err := m.table.Run(ctx.Done())
  245. if err != nil {
  246. return fmt.Errorf("failed to watch for route table updates: %v", err)
  247. }
  248. go func() {
  249. for {
  250. var err error
  251. select {
  252. case err = <-ipTablesErrors:
  253. case err = <-routeErrors:
  254. case <-ctx.Done():
  255. return
  256. }
  257. if err != nil {
  258. _ = level.Error(m.logger).Log("error", err)
  259. m.errorCounter.WithLabelValues("run").Inc()
  260. }
  261. }
  262. }()
  263. if m.cleanup {
  264. defer m.cleanUp()
  265. }
  266. resync := time.NewTimer(m.resyncPeriod)
  267. checkIn := time.NewTimer(checkInPeriod)
  268. nw := m.Nodes().Watch()
  269. pw := m.Peers().Watch()
  270. var ne *NodeEvent
  271. var pe *PeerEvent
  272. for {
  273. select {
  274. case ne = <-nw:
  275. m.syncNodes(ctx, ne)
  276. case pe = <-pw:
  277. m.syncPeers(pe)
  278. case <-checkIn.C:
  279. m.checkIn(ctx)
  280. checkIn.Reset(checkInPeriod)
  281. case <-resync.C:
  282. if m.cni {
  283. m.updateCNIConfig()
  284. }
  285. m.applyTopology()
  286. resync.Reset(m.resyncPeriod)
  287. case <-ctx.Done():
  288. return nil
  289. }
  290. }
  291. }
  292. func (m *Mesh) syncNodes(ctx context.Context, e *NodeEvent) {
  293. logger := log.With(m.logger, "event", e.Type)
  294. _ = level.Debug(logger).Log("msg", "syncing nodes", "event", e.Type)
  295. if isSelf(m.hostname, e.Node) {
  296. _ = level.Debug(logger).Log("msg", "processing local node", "node", e.Node)
  297. m.handleLocal(ctx, e.Node)
  298. return
  299. }
  300. var diff bool
  301. m.mu.Lock()
  302. if !e.Node.Ready() {
  303. // Trace non ready nodes with their presence in the mesh.
  304. _, ok := m.nodes[e.Node.Name]
  305. _ = level.Debug(logger).Log("msg", "received non ready node", "node", e.Node, "in-mesh", ok)
  306. }
  307. switch e.Type {
  308. case AddEvent:
  309. fallthrough
  310. case UpdateEvent:
  311. if !nodesAreEqual(m.nodes[e.Node.Name], e.Node) {
  312. diff = true
  313. }
  314. // Even if the nodes are the same,
  315. // overwrite the old node to update the timestamp.
  316. m.nodes[e.Node.Name] = e.Node
  317. case DeleteEvent:
  318. delete(m.nodes, e.Node.Name)
  319. diff = true
  320. }
  321. m.mu.Unlock()
  322. if diff {
  323. _ = level.Info(logger).Log("node", e.Node)
  324. m.applyTopology()
  325. }
  326. }
  327. func (m *Mesh) syncPeers(e *PeerEvent) {
  328. logger := log.With(m.logger, "event", e.Type)
  329. _ = level.Debug(logger).Log("msg", "syncing peers", "event", e.Type)
  330. var diff bool
  331. m.mu.Lock()
  332. // Peers are indexed by public key.
  333. key := e.Peer.PublicKey.String()
  334. if !e.Peer.Ready() {
  335. // Trace non ready peer with their presence in the mesh.
  336. _, ok := m.peers[key]
  337. _ = level.Debug(logger).Log("msg", "received non ready peer", "peer", e.Peer, "in-mesh", ok)
  338. }
  339. switch e.Type {
  340. case AddEvent:
  341. fallthrough
  342. case UpdateEvent:
  343. if e.Old != nil && key != e.Old.PublicKey.String() {
  344. delete(m.peers, e.Old.PublicKey.String())
  345. diff = true
  346. }
  347. if !peersAreEqual(m.peers[key], e.Peer) {
  348. m.peers[key] = e.Peer
  349. diff = true
  350. }
  351. case DeleteEvent:
  352. delete(m.peers, key)
  353. diff = true
  354. }
  355. m.mu.Unlock()
  356. if diff {
  357. _ = level.Info(logger).Log("peer", e.Peer)
  358. m.applyTopology()
  359. }
  360. }
  361. // checkIn will try to update the local node's LastSeen timestamp
  362. // in the backend.
  363. func (m *Mesh) checkIn(ctx context.Context) {
  364. m.mu.Lock()
  365. defer m.mu.Unlock()
  366. n := m.nodes[m.hostname]
  367. if n == nil {
  368. _ = level.Debug(m.logger).Log("msg", "no local node found in backend")
  369. return
  370. }
  371. oldTime := n.LastSeen
  372. n.LastSeen = time.Now().Unix()
  373. if err := m.Nodes().Set(ctx, m.hostname, n); err != nil {
  374. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", n)
  375. m.errorCounter.WithLabelValues("checkin").Inc()
  376. // Revert time.
  377. n.LastSeen = oldTime
  378. return
  379. }
  380. _ = level.Debug(m.logger).Log("msg", "successfully checked in local node in backend")
  381. }
  382. func (m *Mesh) handleLocal(ctx context.Context, n *Node) {
  383. // Allow the IPs to be overridden.
  384. if !n.Endpoint.Ready() {
  385. e := wireguard.NewEndpoint(m.externalIP.IP, m.port)
  386. _ = level.Info(m.logger).Log("msg", "overriding endpoint", "node", m.hostname, "old endpoint", n.Endpoint.String(), "new endpoint", e.String())
  387. n.Endpoint = e
  388. }
  389. if n.InternalIP == nil && !n.NoInternalIP {
  390. n.InternalIP = m.internalIP
  391. }
  392. // Compare the given node to the calculated local node.
  393. // Take leader, location, and subnet from the argument, as these
  394. // are not determined by kilo.
  395. local := &Node{
  396. Endpoint: n.Endpoint,
  397. Key: m.pub,
  398. NoInternalIP: n.NoInternalIP,
  399. InternalIP: n.InternalIP,
  400. CNICompatibilityIP: m.enc.CNICompatibilityIP(),
  401. LastSeen: time.Now().Unix(),
  402. Leader: n.Leader,
  403. Location: n.Location,
  404. Name: m.hostname,
  405. PersistentKeepalive: n.PersistentKeepalive,
  406. Subnet: n.Subnet,
  407. WireGuardIP: m.wireGuardIP,
  408. DiscoveredEndpoints: n.DiscoveredEndpoints,
  409. AllowedLocationIPs: n.AllowedLocationIPs,
  410. Granularity: m.granularity,
  411. }
  412. if !nodesAreEqual(n, local) {
  413. _ = level.Debug(m.logger).Log("msg", "local node differs from backend")
  414. if err := m.Nodes().Set(ctx, m.hostname, local); err != nil {
  415. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to set local node: %v", err), "node", local)
  416. m.errorCounter.WithLabelValues("local").Inc()
  417. return
  418. }
  419. _ = level.Debug(m.logger).Log("msg", "successfully reconciled local node against backend")
  420. }
  421. m.mu.Lock()
  422. n = m.nodes[m.hostname]
  423. if n == nil {
  424. n = &Node{}
  425. }
  426. m.mu.Unlock()
  427. if !nodesAreEqual(n, local) {
  428. m.mu.Lock()
  429. m.nodes[local.Name] = local
  430. m.mu.Unlock()
  431. m.applyTopology()
  432. }
  433. }
  434. func (m *Mesh) applyTopology() {
  435. m.reconcileCounter.Inc()
  436. m.mu.Lock()
  437. defer m.mu.Unlock()
  438. // If we can't resolve an endpoint, then fail and retry later.
  439. if err := m.resolveEndpoints(); err != nil {
  440. _ = level.Error(m.logger).Log("error", err)
  441. m.errorCounter.WithLabelValues("apply").Inc()
  442. return
  443. }
  444. // Ensure only ready nodes are considered.
  445. nodes := make(map[string]*Node)
  446. var readyNodes float64
  447. for k := range m.nodes {
  448. m.nodes[k].Granularity = m.granularity
  449. if !m.nodes[k].Ready() {
  450. continue
  451. }
  452. // Make it point to the node without copy.
  453. nodes[k] = m.nodes[k]
  454. readyNodes++
  455. }
  456. // Ensure only ready nodes are considered.
  457. peers := make(map[string]*Peer)
  458. var readyPeers float64
  459. for k := range m.peers {
  460. if !m.peers[k].Ready() {
  461. continue
  462. }
  463. // Make it point the peer without copy.
  464. peers[k] = m.peers[k]
  465. readyPeers++
  466. }
  467. m.nodesGuage.Set(readyNodes)
  468. m.peersGuage.Set(readyPeers)
  469. // We cannot do anything with the topology until the local node is available.
  470. if nodes[m.hostname] == nil {
  471. return
  472. }
  473. // Re-detect MTU if auto mode is enabled.
  474. if m.autoMTU {
  475. m.mtu = detectMTU(m.logger)
  476. }
  477. // Ensure the WireGuard interface has the correct MTU.
  478. if err := wireguard.SetMTU(m.kiloIface, m.mtu); err != nil {
  479. _ = level.Error(m.logger).Log("error", fmt.Errorf("failed to set MTU on WireGuard interface: %v", err))
  480. m.errorCounter.WithLabelValues("apply").Inc()
  481. return
  482. }
  483. // Find the Kilo interface name.
  484. link, err := linkByIndex(m.kiloIface)
  485. if err != nil {
  486. _ = level.Error(m.logger).Log("error", err)
  487. m.errorCounter.WithLabelValues("apply").Inc()
  488. return
  489. }
  490. wgClient, err := wgctrl.New()
  491. if err != nil {
  492. _ = level.Error(m.logger).Log("error", err)
  493. m.errorCounter.WithLabelValues("apply").Inc()
  494. return
  495. }
  496. defer func() { _ = wgClient.Close() }()
  497. // wgDevice is the current configuration of the wg interface.
  498. wgDevice, err := wgClient.Device(m.kiloIfaceName)
  499. if err != nil {
  500. _ = level.Error(m.logger).Log("error", err)
  501. m.errorCounter.WithLabelValues("apply").Inc()
  502. return
  503. }
  504. natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
  505. nodes[m.hostname].DiscoveredEndpoints = natEndpoints
  506. t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, m.serviceCIDRs, nodes[m.hostname].PersistentKeepalive, m.logger)
  507. if err != nil {
  508. _ = level.Error(m.logger).Log("error", err)
  509. m.errorCounter.WithLabelValues("apply").Inc()
  510. return
  511. }
  512. // Update the node's WireGuard IP.
  513. if t.leader {
  514. m.wireGuardIP = t.wireGuardCIDR
  515. } else {
  516. m.wireGuardIP = nil
  517. }
  518. ipRules := t.Rules(m.cni, m.iptablesForwardRule)
  519. // If we are handling local routes, ensure the local
  520. // tunnel has an IP address and IPIP traffic is allowed.
  521. if m.enc.Strategy() != encapsulation.Never && m.local {
  522. var cidrs []*net.IPNet
  523. for _, s := range t.segments {
  524. // If the location prefix is not logicalLocation, but nodeLocation,
  525. // we don't need to set any extra rules for encapsulation anyways
  526. // because traffic will go over WireGuard.
  527. if s.location == logicalLocationPrefix+nodes[m.hostname].Location {
  528. for i := range s.privateIPs {
  529. cidrs = append(cidrs, oneAddressCIDR(s.privateIPs[i]))
  530. }
  531. break
  532. }
  533. }
  534. encIpRules := m.enc.Rules(cidrs)
  535. ipRules = encIpRules.AppendRuleSet(ipRules)
  536. // If we are handling local routes, ensure the local
  537. // tunnel has an IP address.
  538. if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
  539. _ = level.Error(m.logger).Log("error", err)
  540. m.errorCounter.WithLabelValues("apply").Inc()
  541. return
  542. }
  543. }
  544. if err := m.ipTables.Set(ipRules); err != nil {
  545. _ = level.Error(m.logger).Log("error", err)
  546. m.errorCounter.WithLabelValues("apply").Inc()
  547. return
  548. }
  549. if t.leader {
  550. m.leaderGuage.Set(1)
  551. if err := iproute.SetAddress(m.kiloIface, t.wireGuardCIDR); err != nil {
  552. _ = level.Error(m.logger).Log("error", err)
  553. m.errorCounter.WithLabelValues("apply").Inc()
  554. return
  555. }
  556. // Setting the WireGuard configuration interrupts existing connections
  557. // so only set the configuration if it has changed.
  558. conf := t.Conf()
  559. equal, diff := conf.Equal(wgDevice)
  560. if !equal {
  561. _ = level.Info(m.logger).Log("msg", "WireGuard configurations are different", "diff", diff)
  562. _ = level.Debug(m.logger).Log("msg", "changing wg config", "config", conf.WGConfig())
  563. if err := wgClient.ConfigureDevice(m.kiloIfaceName, conf.WGConfig()); err != nil {
  564. _ = level.Error(m.logger).Log("error", err)
  565. m.errorCounter.WithLabelValues("apply").Inc()
  566. return
  567. }
  568. }
  569. if err := iproute.Set(m.kiloIface, true); err != nil {
  570. _ = level.Error(m.logger).Log("error", err)
  571. m.errorCounter.WithLabelValues("apply").Inc()
  572. return
  573. }
  574. } else {
  575. m.leaderGuage.Set(0)
  576. _ = level.Debug(m.logger).Log("msg", "local node is not the leader")
  577. if err := iproute.Set(m.kiloIface, false); err != nil {
  578. _ = level.Error(m.logger).Log("error", err)
  579. m.errorCounter.WithLabelValues("apply").Inc()
  580. return
  581. }
  582. }
  583. // We need to add routes last since they may depend
  584. // on the WireGuard interface.
  585. routes, rules := t.Routes(link.Attrs().Name, m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc)
  586. if err := m.table.Set(routes, rules); err != nil {
  587. _ = level.Error(m.logger).Log("error", err)
  588. m.errorCounter.WithLabelValues("apply").Inc()
  589. }
  590. }
  591. func (m *Mesh) cleanUp() {
  592. if err := m.ipTables.CleanUp(); err != nil {
  593. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up IP tables: %v", err))
  594. m.errorCounter.WithLabelValues("cleanUp").Inc()
  595. }
  596. if err := m.table.CleanUp(); err != nil {
  597. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up routes: %v", err))
  598. m.errorCounter.WithLabelValues("cleanUp").Inc()
  599. }
  600. if m.cleanUpIface {
  601. if err := iproute.RemoveInterface(m.kiloIface); err != nil {
  602. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to remove WireGuard interface: %v", err))
  603. m.errorCounter.WithLabelValues("cleanUp").Inc()
  604. }
  605. }
  606. {
  607. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  608. defer cancel()
  609. if err := m.Nodes().CleanUp(ctx, m.hostname); err != nil {
  610. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up node backend: %v", err))
  611. m.errorCounter.WithLabelValues("cleanUp").Inc()
  612. }
  613. }
  614. {
  615. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  616. defer cancel()
  617. if err := m.Peers().CleanUp(ctx, m.hostname); err != nil {
  618. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
  619. m.errorCounter.WithLabelValues("cleanUp").Inc()
  620. }
  621. }
  622. if err := m.enc.CleanUp(); err != nil {
  623. _ = level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulator: %v", err))
  624. m.errorCounter.WithLabelValues("cleanUp").Inc()
  625. }
  626. }
  627. func (m *Mesh) resolveEndpoints() error {
  628. for k := range m.nodes {
  629. // Skip unready nodes, since they will not be used
  630. // in the topology anyways.
  631. if !m.nodes[k].Ready() {
  632. continue
  633. }
  634. // Resolve the Endpoint
  635. if _, err := m.nodes[k].Endpoint.UDPAddr(true); err != nil {
  636. return err
  637. }
  638. }
  639. for k := range m.peers {
  640. // Skip unready peers, since they will not be used
  641. // in the topology anyways.
  642. if !m.peers[k].Ready() {
  643. continue
  644. }
  645. // Peers may have nil endpoints.
  646. if !m.peers[k].Endpoint.Ready() {
  647. continue
  648. }
  649. if _, err := m.peers[k].Endpoint.UDPAddr(true); err != nil {
  650. return err
  651. }
  652. }
  653. return nil
  654. }
  655. func isSelf(hostname string, node *Node) bool {
  656. return node != nil && node.Name == hostname
  657. }
  658. func nodesAreEqual(a, b *Node) bool {
  659. if (a != nil) != (b != nil) {
  660. return false
  661. }
  662. if a == b {
  663. return true
  664. }
  665. // Check the DNS name first since this package
  666. // is doing the DNS resolution.
  667. if !a.Endpoint.Equal(b.Endpoint, true) {
  668. return false
  669. }
  670. // Ignore LastSeen when comparing equality we want to check if the nodes are
  671. // equivalent. However, we do want to check if LastSeen has transitioned
  672. // between valid and invalid.
  673. return a.Key.String() == b.Key.String() &&
  674. ipNetsEqual(a.WireGuardIP, b.WireGuardIP) &&
  675. ipNetsEqual(a.InternalIP, b.InternalIP) &&
  676. ipNetsEqual(a.CNICompatibilityIP, b.CNICompatibilityIP) &&
  677. a.Leader == b.Leader &&
  678. a.Location == b.Location &&
  679. a.Name == b.Name &&
  680. subnetsEqual(a.Subnet, b.Subnet) &&
  681. a.Ready() == b.Ready() &&
  682. a.PersistentKeepalive == b.PersistentKeepalive &&
  683. discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) &&
  684. ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs) &&
  685. a.Granularity == b.Granularity
  686. }
  687. func peersAreEqual(a, b *Peer) bool {
  688. if (a != nil) != (b != nil) {
  689. return false
  690. }
  691. if a == b {
  692. return true
  693. }
  694. // Check the DNS name first since this package
  695. // is doing the DNS resolution.
  696. if !a.Endpoint.Equal(b.Endpoint, true) {
  697. return false
  698. }
  699. if len(a.AllowedIPs) != len(b.AllowedIPs) {
  700. return false
  701. }
  702. for i := range a.AllowedIPs {
  703. if !ipNetsEqual(&a.AllowedIPs[i], &b.AllowedIPs[i]) {
  704. return false
  705. }
  706. }
  707. return a.PublicKey.String() == b.PublicKey.String() &&
  708. (a.PresharedKey == nil) == (b.PresharedKey == nil) &&
  709. (a.PresharedKey == nil || a.PresharedKey.String() == b.PresharedKey.String()) &&
  710. (a.PersistentKeepaliveInterval == nil) == (b.PersistentKeepaliveInterval == nil) &&
  711. (a.PersistentKeepaliveInterval == nil || *a.PersistentKeepaliveInterval == *b.PersistentKeepaliveInterval)
  712. }
  713. func ipNetsEqual(a, b *net.IPNet) bool {
  714. if a == nil && b == nil {
  715. return true
  716. }
  717. if (a != nil) != (b != nil) {
  718. return false
  719. }
  720. if a.Mask.String() != b.Mask.String() {
  721. return false
  722. }
  723. return a.IP.Equal(b.IP)
  724. }
  725. func ipNetSlicesEqual(a, b []net.IPNet) bool {
  726. if len(a) != len(b) {
  727. return false
  728. }
  729. for i := range a {
  730. if !ipNetsEqual(&a[i], &b[i]) {
  731. return false
  732. }
  733. }
  734. return true
  735. }
  736. func subnetsEqual(a, b *net.IPNet) bool {
  737. if a == nil && b == nil {
  738. return true
  739. }
  740. if (a != nil) != (b != nil) {
  741. return false
  742. }
  743. if a.Mask.String() != b.Mask.String() {
  744. return false
  745. }
  746. if !a.Contains(b.IP) {
  747. return false
  748. }
  749. if !b.Contains(a.IP) {
  750. return false
  751. }
  752. return true
  753. }
  754. func udpAddrsEqual(a, b *net.UDPAddr) bool {
  755. if a == nil && b == nil {
  756. return true
  757. }
  758. if (a != nil) != (b != nil) {
  759. return false
  760. }
  761. if a.Zone != b.Zone {
  762. return false
  763. }
  764. if a.Port != b.Port {
  765. return false
  766. }
  767. return a.IP.Equal(b.IP)
  768. }
  769. func discoveredEndpointsAreEqual(a, b map[string]*net.UDPAddr) bool {
  770. if a == nil && b == nil {
  771. return true
  772. }
  773. if len(a) != len(b) {
  774. return false
  775. }
  776. for k := range a {
  777. if !udpAddrsEqual(a[k], b[k]) {
  778. return false
  779. }
  780. }
  781. return true
  782. }
  783. func detectMTU(logger log.Logger) uint {
  784. iface, err := defaultInterface()
  785. if err != nil {
  786. _ = level.Warn(logger).Log("msg", "failed to get default interface for MTU detection, using default MTU", "error", err)
  787. return wireguard.DefaultMTU
  788. }
  789. link, err := netlink.LinkByIndex(iface.Index)
  790. if err != nil {
  791. _ = level.Warn(logger).Log("msg", "failed to get default interface link for MTU detection, using default MTU", "error", err)
  792. return wireguard.DefaultMTU
  793. }
  794. baseMTU := link.Attrs().MTU
  795. _ = level.Info(logger).Log("msg", fmt.Sprintf("detected underlay MTU %d on default interface %s", baseMTU, link.Attrs().Name))
  796. mtu := uint(baseMTU) - wireguard.WireGuardOverhead
  797. _ = level.Info(logger).Log("msg", fmt.Sprintf("auto-detected WireGuard MTU: %d (underlay %d - overhead %d)", mtu, baseMTU, wireguard.WireGuardOverhead))
  798. return mtu
  799. }
  800. func linkByIndex(index int) (netlink.Link, error) {
  801. link, err := netlink.LinkByIndex(index)
  802. if err != nil {
  803. return nil, fmt.Errorf("failed to get interface: %v", err)
  804. }
  805. return link, nil
  806. }
  807. // discoverNATEndpoints uses the node's WireGuard configuration to returns a list of the most recently discovered endpoints for all nodes and peers behind NAT so that they can roam.
  808. // Discovered endpionts will never be DNS names, because WireGuard will always resolve them to net.UDPAddr.
  809. func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *wgtypes.Device, logger log.Logger) map[string]*net.UDPAddr {
  810. natEndpoints := make(map[string]*net.UDPAddr)
  811. keys := make(map[string]wgtypes.Peer)
  812. for i := range conf.Peers {
  813. keys[conf.Peers[i].PublicKey.String()] = conf.Peers[i]
  814. }
  815. for _, n := range nodes {
  816. if peer, ok := keys[n.Key.String()]; ok && n.PersistentKeepalive != time.Duration(0) {
  817. _ = level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", peer.Endpoint.String() == n.Endpoint.String(), "latest-handshake", peer.LastHandshakeTime)
  818. // Don't update the endpoint, if there was never any handshake.
  819. if !peer.LastHandshakeTime.Equal(time.Time{}) {
  820. natEndpoints[n.Key.String()] = peer.Endpoint
  821. }
  822. }
  823. }
  824. for _, p := range peers {
  825. if peer, ok := keys[p.PublicKey.String()]; ok && p.PersistentKeepaliveInterval != nil {
  826. if !peer.LastHandshakeTime.Equal(time.Time{}) {
  827. natEndpoints[p.PublicKey.String()] = peer.Endpoint
  828. }
  829. }
  830. }
  831. _ = level.Debug(logger).Log("msg", "Discovered WireGuard NAT Endpoints", "DiscoveredEndpoints", natEndpoints)
  832. return natEndpoints
  833. }