topology.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. // Copyright 2021 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mesh
import (
	"errors"
	"fmt"
	"net"
	"sort"
	"time"

	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	"github.com/squat/kilo/pkg/wireguard"
)
  25. const (
  26. logicalLocationPrefix = "location:"
  27. nodeLocationPrefix = "node:"
  28. )
// Topology represents the logical structure of the overlay network.
// It is built by NewTopology and rendered into WireGuard configuration
// by Conf, PeerConf and AsPeer.
type Topology struct {
	// key is the private key of the node creating the topology.
	key wgtypes.Key
	// port is the WireGuard listen port of the local node; it is emitted
	// as ListenPort by Conf.
	port int
	// location is the logical location of the local host, including its
	// "location:" or "node:" prefix.
	location string
	// segments holds one segment per location, sorted by location name.
	segments []*segment
	// peers holds the peers passed to NewTopology (separately from the
	// nodes), sorted by name, with allowed IPs that were claimed by more
	// than one peer removed from all but the first claimant.
	peers []*Peer
	// hostname is the hostname of the local host.
	hostname string
	// leader represents whether or not the local host
	// is the segment leader.
	leader bool
	// persistentKeepalive is the interval between keepalive packets
	// emitted by the local node to its peers. It is a time.Duration, not
	// a count of seconds.
	persistentKeepalive time.Duration
	// privateIP is the private IP address of the local node.
	privateIP *net.IPNet
	// subnet is the Pod subnet of the local node.
	subnet *net.IPNet
	// wireGuardCIDR is the allocated CIDR of the WireGuard
	// interface of the local node within the Kilo subnet.
	// If the local node is not the leader of a location, then
	// the IP is the 0th address in the subnet, i.e. the CIDR
	// is equal to the Kilo subnet.
	wireGuardCIDR *net.IPNet
	// discoveredEndpoints maps a peer's public key (in its String form) to
	// an endpoint discovered for it at runtime; see updateEndpoint.
	discoveredEndpoints map[string]*net.UDPAddr
	// logger receives debug and warning messages about the topology.
	logger log.Logger
}
// segment represents one logical unit in the topology that is united by one common WireGuard IP.
type segment struct {
	// allowedIPs holds the subnets and internal IPs (as single-address
	// CIDRs) of the segment's nodes, plus the segment's WireGuard IP once
	// it has been allocated.
	allowedIPs []net.IPNet
	// endpoint is the endpoint of the segment leader.
	endpoint *wireguard.Endpoint
	// key is the public key of the segment leader.
	key wgtypes.Key
	// persistentKeepalive is the segment leader's keepalive interval; a
	// non-zero value lets updateEndpoint substitute a discovered endpoint.
	persistentKeepalive time.Duration
	// location is the logical location of this segment.
	location string
	// cidrs is a slice of subnets of all nodes in the segment, in node
	// order; an entry is nil for a node that has no subnet.
	cidrs []*net.IPNet
	// hostnames is a slice of the hostnames of the nodes in the segment,
	// sorted by name.
	hostnames []string
	// leader is the index of the leader of the segment in the sorted node
	// list, and therefore also in hostnames and cidrs. It does not
	// necessarily index privateIPs, which skips nodes without an internal IP.
	leader int
	// privateIPs is a slice of the private IPs of those nodes in the
	// segment that have an internal IP.
	privateIPs []net.IP
	// wireGuardIP is the allocated IP address of the WireGuard
	// interface on the leader of the segment.
	wireGuardIP net.IP
	// allowedLocationIPs are not part of the cluster and are not peers.
	// They are directly routable from nodes within the segment.
	// A classic example is a printer that ought to be routable from other locations.
	// The list is deduplicated, sorted by string form, and stripped of
	// entries that overlap other segments or peers.
	allowedLocationIPs []net.IPNet
}
  84. // NewTopology creates a new Topology struct from a given set of nodes and peers.
  85. func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
  86. if logger == nil {
  87. logger = log.NewNopLogger()
  88. }
  89. topoMap := make(map[string][]*Node)
  90. for _, node := range nodes {
  91. var location string
  92. switch granularity {
  93. case LogicalGranularity:
  94. location = logicalLocationPrefix + node.Location
  95. // Put node in a different location, if no private
  96. // IP was found.
  97. if node.InternalIP == nil {
  98. location = nodeLocationPrefix + node.Name
  99. }
  100. case FullGranularity:
  101. location = nodeLocationPrefix + node.Name
  102. }
  103. topoMap[location] = append(topoMap[location], node)
  104. }
  105. var localLocation string
  106. switch granularity {
  107. case LogicalGranularity:
  108. localLocation = logicalLocationPrefix + nodes[hostname].Location
  109. if nodes[hostname].InternalIP == nil {
  110. localLocation = nodeLocationPrefix + hostname
  111. }
  112. case FullGranularity:
  113. localLocation = nodeLocationPrefix + hostname
  114. }
  115. t := Topology{
  116. key: key,
  117. port: port,
  118. hostname: hostname,
  119. location: localLocation,
  120. persistentKeepalive: persistentKeepalive,
  121. privateIP: nodes[hostname].InternalIP,
  122. subnet: nodes[hostname].Subnet,
  123. wireGuardCIDR: subnet,
  124. discoveredEndpoints: make(map[string]*net.UDPAddr),
  125. logger: logger,
  126. }
  127. for location := range topoMap {
  128. // Sort the location so the result is stable.
  129. sort.Slice(topoMap[location], func(i, j int) bool {
  130. return topoMap[location][i].Name < topoMap[location][j].Name
  131. })
  132. leader := findLeader(topoMap[location])
  133. if location == localLocation && topoMap[location][leader].Name == hostname {
  134. t.leader = true
  135. }
  136. var allowedIPs []net.IPNet
  137. allowedLocationIPsMap := make(map[string]struct{})
  138. var allowedLocationIPs []net.IPNet
  139. var cidrs []*net.IPNet
  140. var hostnames []string
  141. var privateIPs []net.IP
  142. for _, node := range topoMap[location] {
  143. // Allowed IPs should include:
  144. // - the node's allocated subnet
  145. // - the node's WireGuard IP
  146. // - the node's internal IP
  147. // - IPs that were specified by the allowed-location-ips annotation
  148. if node.Subnet != nil {
  149. allowedIPs = append(allowedIPs, *node.Subnet)
  150. }
  151. for _, ip := range node.AllowedLocationIPs {
  152. if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
  153. allowedLocationIPs = append(allowedLocationIPs, ip)
  154. allowedLocationIPsMap[ip.String()] = struct{}{}
  155. }
  156. }
  157. if node.InternalIP != nil {
  158. allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP))
  159. privateIPs = append(privateIPs, node.InternalIP.IP)
  160. }
  161. cidrs = append(cidrs, node.Subnet)
  162. hostnames = append(hostnames, node.Name)
  163. }
  164. // The sorting has no function, but makes testing easier.
  165. sort.Slice(allowedLocationIPs, func(i, j int) bool {
  166. return allowedLocationIPs[i].String() < allowedLocationIPs[j].String()
  167. })
  168. t.segments = append(t.segments, &segment{
  169. allowedIPs: allowedIPs,
  170. endpoint: topoMap[location][leader].Endpoint,
  171. key: topoMap[location][leader].Key,
  172. persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
  173. location: location,
  174. cidrs: cidrs,
  175. hostnames: hostnames,
  176. leader: leader,
  177. privateIPs: privateIPs,
  178. allowedLocationIPs: allowedLocationIPs,
  179. })
  180. level.Debug(t.logger).Log("msg", "generated segment", "location", location, "allowedIPs", allowedIPs, "endpoint", topoMap[location][leader].Endpoint, "cidrs", cidrs, "hostnames", hostnames, "leader", leader, "privateIPs", privateIPs, "allowedLocationIPs", allowedLocationIPs)
  181. }
  182. // Sort the Topology segments so the result is stable.
  183. sort.Slice(t.segments, func(i, j int) bool {
  184. return t.segments[i].location < t.segments[j].location
  185. })
  186. for _, peer := range peers {
  187. t.peers = append(t.peers, peer)
  188. }
  189. // Sort the Topology peers so the result is stable.
  190. sort.Slice(t.peers, func(i, j int) bool {
  191. return t.peers[i].Name < t.peers[j].Name
  192. })
  193. // We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP,
  194. // the WireGuard configuration could flap, causing the interface to churn.
  195. t.peers = deduplicatePeerIPs(t.peers)
  196. // Copy the host node DiscoveredEndpoints in the topology as a starting point.
  197. for key := range nodes[hostname].DiscoveredEndpoints {
  198. t.discoveredEndpoints[key] = nodes[hostname].DiscoveredEndpoints[key]
  199. }
  200. // Allocate IPs to the segment leaders in a stable, coordination-free manner.
  201. a := newAllocator(*subnet)
  202. for _, segment := range t.segments {
  203. ipNet := a.next()
  204. if ipNet == nil {
  205. return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
  206. }
  207. segment.wireGuardIP = ipNet.IP
  208. segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP))
  209. if t.leader && segment.location == t.location {
  210. t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
  211. }
  212. // Now that the topology is ordered, update the discoveredEndpoints map
  213. // add new ones by going through the ordered topology: segments, nodes
  214. for _, node := range topoMap[segment.location] {
  215. for key := range node.DiscoveredEndpoints {
  216. if _, ok := t.discoveredEndpoints[key]; !ok {
  217. t.discoveredEndpoints[key] = node.DiscoveredEndpoints[key]
  218. }
  219. }
  220. }
  221. // Check for intersecting IPs in allowed location IPs
  222. segment.allowedLocationIPs = t.filterAllowedLocationIPs(segment.allowedLocationIPs, segment.location)
  223. }
  224. level.Debug(t.logger).Log("msg", "generated topology", "location", t.location, "hostname", t.hostname, "wireGuardIP", t.wireGuardCIDR, "privateIP", t.privateIP, "subnet", t.subnet, "leader", t.leader)
  225. return &t, nil
  226. }
  227. func intersect(n1, n2 net.IPNet) bool {
  228. return n1.Contains(n2.IP) || n2.Contains(n1.IP)
  229. }
  230. func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) {
  231. CheckIPs:
  232. for _, ip := range ips {
  233. for _, s := range t.segments {
  234. // Check if allowed location IPs are also allowed in other locations.
  235. if location != s.location {
  236. for _, i := range s.allowedLocationIPs {
  237. if intersect(ip, i) {
  238. level.Warn(t.logger).Log("msg", "overlapping allowed location IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  239. continue CheckIPs
  240. }
  241. }
  242. }
  243. // Check if allowed location IPs intersect with the allowed IPs.
  244. for _, i := range s.allowedIPs {
  245. if intersect(ip, i) {
  246. level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with allowed IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  247. continue CheckIPs
  248. }
  249. }
  250. // Check if allowed location IPs intersect with the private IPs of the segment.
  251. for _, i := range s.privateIPs {
  252. if ip.Contains(i) {
  253. level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with privateIP", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  254. continue CheckIPs
  255. }
  256. }
  257. }
  258. // Check if allowed location IPs intersect with allowed IPs of peers.
  259. for _, p := range t.peers {
  260. for _, i := range p.AllowedIPs {
  261. if intersect(ip, i) {
  262. level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with peer IPnet", "IP", ip.String(), "IP2", i.String(), "peer", p.Name)
  263. continue CheckIPs
  264. }
  265. }
  266. }
  267. ret = append(ret, ip)
  268. }
  269. return
  270. }
  271. func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
  272. // Do not update non-nat peers
  273. if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
  274. return endpoint
  275. }
  276. e, ok := t.discoveredEndpoints[key.String()]
  277. if ok {
  278. return wireguard.NewEndpointFromUDPAddr(e)
  279. }
  280. return endpoint
  281. }
  282. // Conf generates a WireGuard configuration file for a given Topology.
  283. func (t *Topology) Conf() *wireguard.Conf {
  284. c := &wireguard.Conf{
  285. Config: wgtypes.Config{
  286. PrivateKey: &t.key,
  287. ListenPort: &t.port,
  288. ReplacePeers: true,
  289. },
  290. }
  291. for _, s := range t.segments {
  292. if s.location == t.location {
  293. continue
  294. }
  295. peer := wireguard.Peer{
  296. PeerConfig: wgtypes.PeerConfig{
  297. AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
  298. PersistentKeepaliveInterval: &t.persistentKeepalive,
  299. PublicKey: s.key,
  300. ReplaceAllowedIPs: true,
  301. },
  302. Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
  303. }
  304. c.Peers = append(c.Peers, peer)
  305. }
  306. for _, p := range t.peers {
  307. peer := wireguard.Peer{
  308. PeerConfig: wgtypes.PeerConfig{
  309. AllowedIPs: p.AllowedIPs,
  310. PersistentKeepaliveInterval: &t.persistentKeepalive,
  311. PresharedKey: p.PresharedKey,
  312. PublicKey: p.PublicKey,
  313. ReplaceAllowedIPs: true,
  314. },
  315. Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
  316. }
  317. c.Peers = append(c.Peers, peer)
  318. }
  319. return c
  320. }
  321. // AsPeer generates the WireGuard peer configuration for the local location of the given Topology.
  322. // This configuration can be used to configure this location as a peer of another WireGuard interface.
  323. func (t *Topology) AsPeer() *wireguard.Peer {
  324. for _, s := range t.segments {
  325. if s.location != t.location {
  326. continue
  327. }
  328. p := &wireguard.Peer{
  329. PeerConfig: wgtypes.PeerConfig{
  330. AllowedIPs: s.allowedIPs,
  331. PublicKey: s.key,
  332. },
  333. Endpoint: s.endpoint,
  334. }
  335. return p
  336. }
  337. return nil
  338. }
  339. // PeerConf generates a WireGuard configuration file for a given peer in a Topology.
  340. func (t *Topology) PeerConf(name string) *wireguard.Conf {
  341. var pka *time.Duration
  342. var psk *wgtypes.Key
  343. for i := range t.peers {
  344. if t.peers[i].Name == name {
  345. pka = t.peers[i].PersistentKeepaliveInterval
  346. psk = t.peers[i].PresharedKey
  347. break
  348. }
  349. }
  350. c := &wireguard.Conf{}
  351. for _, s := range t.segments {
  352. peer := wireguard.Peer{
  353. PeerConfig: wgtypes.PeerConfig{
  354. AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
  355. PersistentKeepaliveInterval: pka,
  356. PresharedKey: psk,
  357. PublicKey: s.key,
  358. },
  359. Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
  360. }
  361. c.Peers = append(c.Peers, peer)
  362. }
  363. for i := range t.peers {
  364. if t.peers[i].Name == name {
  365. continue
  366. }
  367. peer := wireguard.Peer{
  368. PeerConfig: wgtypes.PeerConfig{
  369. AllowedIPs: t.peers[i].AllowedIPs,
  370. PersistentKeepaliveInterval: pka,
  371. PublicKey: t.peers[i].PublicKey,
  372. },
  373. Endpoint: t.updateEndpoint(t.peers[i].Endpoint, t.peers[i].PublicKey, t.peers[i].PersistentKeepaliveInterval),
  374. }
  375. c.Peers = append(c.Peers, peer)
  376. }
  377. return c
  378. }
  379. // oneAddressCIDR takes an IP address and returns a CIDR
  380. // that contains only that address.
  381. func oneAddressCIDR(ip net.IP) *net.IPNet {
  382. return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
  383. }
  384. // findLeader selects a leader for the nodes in a segment;
  385. // it will select the first node that says it should lead
  386. // or the first node in the segment if none have volunteered,
  387. // always preferring those with a public external IP address,
  388. func findLeader(nodes []*Node) int {
  389. var leaders, public []int
  390. for i := range nodes {
  391. if nodes[i].Leader {
  392. if isPublic(nodes[i].Endpoint.IP()) {
  393. return i
  394. }
  395. leaders = append(leaders, i)
  396. }
  397. if nodes[i].Endpoint.IP() != nil && isPublic(nodes[i].Endpoint.IP()) {
  398. public = append(public, i)
  399. }
  400. }
  401. if len(leaders) != 0 {
  402. return leaders[0]
  403. }
  404. if len(public) != 0 {
  405. return public[0]
  406. }
  407. return 0
  408. }
  409. func deduplicatePeerIPs(peers []*Peer) []*Peer {
  410. ps := make([]*Peer, len(peers))
  411. ips := make(map[string]struct{})
  412. for i, peer := range peers {
  413. p := Peer{
  414. Name: peer.Name,
  415. Peer: wireguard.Peer{
  416. PeerConfig: wgtypes.PeerConfig{
  417. PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
  418. PresharedKey: peer.PresharedKey,
  419. PublicKey: peer.PublicKey,
  420. },
  421. Endpoint: peer.Endpoint,
  422. },
  423. }
  424. for _, ip := range peer.AllowedIPs {
  425. if _, ok := ips[ip.String()]; ok {
  426. continue
  427. }
  428. p.AllowedIPs = append(p.AllowedIPs, ip)
  429. ips[ip.String()] = struct{}{}
  430. }
  431. ps[i] = &p
  432. }
  433. return ps
  434. }