topology.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. // Copyright 2021 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mesh
  15. import (
  16. "errors"
  17. "net"
  18. "sort"
  19. "time"
  20. "github.com/go-kit/kit/log"
  21. "github.com/go-kit/kit/log/level"
  22. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  23. "github.com/squat/kilo/pkg/wireguard"
  24. )
  25. const (
  26. logicalLocationPrefix = "location:"
  27. nodeLocationPrefix = "node:"
  28. )
// Topology represents the logical structure of the overlay network as seen
// from the local node.
type Topology struct {
	// key is the private key of the node creating the topology.
	key wgtypes.Key
	// port is the WireGuard listen port of the local node.
	port int
	// location is the logical location of the local host.
	location string
	// segments are the locations of the topology, sorted by location.
	segments []*segment
	// peers are the additional WireGuard peers given to NewTopology
	// (distinct from the nodes), sorted by name with duplicate allowed
	// IPs removed.
	peers []*Peer
	// hostname is the hostname of the local host.
	hostname string
	// leader represents whether or not the local host
	// is the segment leader.
	leader bool
	// persistentKeepalive is the interval at which the local node emits
	// keepalive packets to its peers.
	persistentKeepalive time.Duration
	// privateIP is the private IP address of the local node.
	privateIP *net.IPNet
	// subnet is the Pod subnet of the local node.
	subnet *net.IPNet
	// wireGuardCIDR is the allocated CIDR of the WireGuard
	// interface of the local node within the Kilo subnet.
	// If the local node is not the leader of a location, then
	// the IP is the 0th address in the subnet, i.e. the CIDR
	// is equal to the Kilo subnet.
	wireGuardCIDR *net.IPNet
	// serviceCIDRs are the known service CIDRs of the Kubernetes cluster.
	// They are not strictly needed, however if they are known,
	// then the topology can avoid masquerading packets destined to service IPs.
	serviceCIDRs []*net.IPNet
	// discoveredEndpoints maps a peer's public key (in string form) to
	// the endpoint discovered for it.
	discoveredEndpoints map[string]*net.UDPAddr
	logger              log.Logger
}
// segment represents one logical unit in the topology that is united by one common WireGuard IP.
type segment struct {
	// allowedIPs are the subnets and internal IPs (as single-address CIDRs)
	// of the segment's nodes, plus the WireGuard IP allocated to the segment.
	allowedIPs []net.IPNet
	// endpoint is the WireGuard endpoint of the segment leader.
	endpoint *wireguard.Endpoint
	// key is the public key of the segment leader.
	key wgtypes.Key
	// persistentKeepalive is the persistent keepalive of the segment leader.
	persistentKeepalive time.Duration
	// location is the logical location of this segment.
	location string
	// cidrs is a slice of subnets of all peers in the segment.
	cidrs []*net.IPNet
	// hostnames is a slice of the hostnames of the peers in the segment.
	hostnames []string
	// leader is the index of the leader within the segment's nodes,
	// which are sorted by name.
	leader int
	// privateIPs is a slice of private IPs of all peers in the segment.
	privateIPs []net.IP
	// wireGuardIP is the allocated IP address of the WireGuard
	// interface on the leader of the segment.
	wireGuardIP net.IP
	// allowedLocationIPs are not part of the cluster and are not peers.
	// They are directly routable from nodes within the segment.
	// A classic example is a printer that ought to be routable from other locations.
	allowedLocationIPs []net.IPNet
}
  88. // NewTopology creates a new Topology struct from a given set of nodes and peers.
  89. func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, serviceCIDRs []*net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
  90. if logger == nil {
  91. logger = log.NewNopLogger()
  92. }
  93. topoMap := make(map[string][]*Node)
  94. for _, node := range nodes {
  95. var location string
  96. switch granularity {
  97. case LogicalGranularity:
  98. location = logicalLocationPrefix + node.Location
  99. // Put node in a different location, if no private
  100. // IP was found.
  101. if node.InternalIP == nil {
  102. location = nodeLocationPrefix + node.Name
  103. }
  104. case FullGranularity:
  105. location = nodeLocationPrefix + node.Name
  106. }
  107. topoMap[location] = append(topoMap[location], node)
  108. }
  109. var localLocation string
  110. switch granularity {
  111. case LogicalGranularity:
  112. localLocation = logicalLocationPrefix + nodes[hostname].Location
  113. if nodes[hostname].InternalIP == nil {
  114. localLocation = nodeLocationPrefix + hostname
  115. }
  116. case FullGranularity:
  117. localLocation = nodeLocationPrefix + hostname
  118. }
  119. t := Topology{
  120. key: key,
  121. port: port,
  122. hostname: hostname,
  123. location: localLocation,
  124. persistentKeepalive: persistentKeepalive,
  125. privateIP: nodes[hostname].InternalIP,
  126. subnet: nodes[hostname].Subnet,
  127. wireGuardCIDR: subnet,
  128. serviceCIDRs: serviceCIDRs,
  129. discoveredEndpoints: make(map[string]*net.UDPAddr),
  130. logger: logger,
  131. }
  132. for location := range topoMap {
  133. // Sort the location so the result is stable.
  134. sort.Slice(topoMap[location], func(i, j int) bool {
  135. return topoMap[location][i].Name < topoMap[location][j].Name
  136. })
  137. leader := findLeader(topoMap[location])
  138. if location == localLocation && topoMap[location][leader].Name == hostname {
  139. t.leader = true
  140. }
  141. var allowedIPs []net.IPNet
  142. allowedLocationIPsMap := make(map[string]struct{})
  143. var allowedLocationIPs []net.IPNet
  144. var cidrs []*net.IPNet
  145. var hostnames []string
  146. var privateIPs []net.IP
  147. for _, node := range topoMap[location] {
  148. // Allowed IPs should include:
  149. // - the node's allocated subnet
  150. // - the node's WireGuard IP
  151. // - the node's internal IP
  152. // - IPs that were specified by the allowed-location-ips annotation
  153. if node.Subnet != nil {
  154. allowedIPs = append(allowedIPs, *node.Subnet)
  155. }
  156. for _, ip := range node.AllowedLocationIPs {
  157. if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
  158. allowedLocationIPs = append(allowedLocationIPs, ip)
  159. allowedLocationIPsMap[ip.String()] = struct{}{}
  160. }
  161. }
  162. if node.InternalIP != nil {
  163. allowedIPs = append(allowedIPs, *oneAddressCIDR(node.InternalIP.IP))
  164. privateIPs = append(privateIPs, node.InternalIP.IP)
  165. }
  166. cidrs = append(cidrs, node.Subnet)
  167. hostnames = append(hostnames, node.Name)
  168. }
  169. // The sorting has no function, but makes testing easier.
  170. sort.Slice(allowedLocationIPs, func(i, j int) bool {
  171. return allowedLocationIPs[i].String() < allowedLocationIPs[j].String()
  172. })
  173. t.segments = append(t.segments, &segment{
  174. allowedIPs: allowedIPs,
  175. endpoint: topoMap[location][leader].Endpoint,
  176. key: topoMap[location][leader].Key,
  177. persistentKeepalive: topoMap[location][leader].PersistentKeepalive,
  178. location: location,
  179. cidrs: cidrs,
  180. hostnames: hostnames,
  181. leader: leader,
  182. privateIPs: privateIPs,
  183. allowedLocationIPs: allowedLocationIPs,
  184. })
  185. _ = level.Debug(t.logger).Log("msg", "generated segment", "location", location, "allowedIPs", allowedIPs, "endpoint", topoMap[location][leader].Endpoint, "cidrs", cidrs, "hostnames", hostnames, "leader", leader, "privateIPs", privateIPs, "allowedLocationIPs", allowedLocationIPs)
  186. }
  187. // Sort the Topology segments so the result is stable.
  188. sort.Slice(t.segments, func(i, j int) bool {
  189. return t.segments[i].location < t.segments[j].location
  190. })
  191. for _, peer := range peers {
  192. t.peers = append(t.peers, peer)
  193. }
  194. // Sort the Topology peers so the result is stable.
  195. sort.Slice(t.peers, func(i, j int) bool {
  196. return t.peers[i].Name < t.peers[j].Name
  197. })
  198. // We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP,
  199. // the WireGuard configuration could flap, causing the interface to churn.
  200. t.peers = deduplicatePeerIPs(t.peers)
  201. // Copy the host node DiscoveredEndpoints in the topology as a starting point.
  202. for key := range nodes[hostname].DiscoveredEndpoints {
  203. t.discoveredEndpoints[key] = nodes[hostname].DiscoveredEndpoints[key]
  204. }
  205. // Allocate IPs to the segment leaders in a stable, coordination-free manner.
  206. a := newAllocator(*subnet)
  207. for _, segment := range t.segments {
  208. ipNet := a.next()
  209. if ipNet == nil {
  210. return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
  211. }
  212. segment.wireGuardIP = ipNet.IP
  213. segment.allowedIPs = append(segment.allowedIPs, *oneAddressCIDR(ipNet.IP))
  214. if t.leader && segment.location == t.location {
  215. t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
  216. }
  217. // Now that the topology is ordered, update the discoveredEndpoints map
  218. // add new ones by going through the ordered topology: segments, nodes
  219. for _, node := range topoMap[segment.location] {
  220. for key := range node.DiscoveredEndpoints {
  221. if _, ok := t.discoveredEndpoints[key]; !ok {
  222. t.discoveredEndpoints[key] = node.DiscoveredEndpoints[key]
  223. }
  224. }
  225. }
  226. // Check for intersecting IPs in allowed location IPs
  227. segment.allowedLocationIPs = t.filterAllowedLocationIPs(segment.allowedLocationIPs, segment.location)
  228. }
  229. _ = level.Debug(t.logger).Log("msg", "generated topology", "location", t.location, "hostname", t.hostname, "wireGuardIP", t.wireGuardCIDR, "privateIP", t.privateIP, "subnet", t.subnet, "leader", t.leader)
  230. return &t, nil
  231. }
  232. func intersect(n1, n2 net.IPNet) bool {
  233. return n1.Contains(n2.IP) || n2.Contains(n1.IP)
  234. }
  235. func (t *Topology) filterAllowedLocationIPs(ips []net.IPNet, location string) (ret []net.IPNet) {
  236. CheckIPs:
  237. for _, ip := range ips {
  238. for _, s := range t.segments {
  239. // Check if allowed location IPs are also allowed in other locations.
  240. if location != s.location {
  241. for _, i := range s.allowedLocationIPs {
  242. if intersect(ip, i) {
  243. _ = level.Warn(t.logger).Log("msg", "overlapping allowed location IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  244. continue CheckIPs
  245. }
  246. }
  247. }
  248. // Check if allowed location IPs intersect with the allowed IPs.
  249. // If the allowed location IP strictly contains an allowed IP, that's
  250. // fine - the more specific route will be used. Reject if the allowed
  251. // IP contains or equals the allowed location IP.
  252. for _, i := range s.allowedIPs {
  253. if i.Contains(ip.IP) {
  254. _ = level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with allowed IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  255. continue CheckIPs
  256. }
  257. }
  258. // Check if allowed location IPs intersect with the private IPs of the segment.
  259. // If the allowed location IP fully contains a private IP, that's fine.
  260. for _, i := range s.privateIPs {
  261. if ip.Contains(i) {
  262. // This is OK - the allowed location IP contains the private IP,
  263. // so the more specific route to the private IP will still work.
  264. _ = level.Debug(t.logger).Log("msg", "allowed location IPnet contains privateIP", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
  265. }
  266. }
  267. }
  268. // Check if allowed location IPs intersect with allowed IPs of peers.
  269. for _, p := range t.peers {
  270. for _, i := range p.AllowedIPs {
  271. if intersect(ip, i) {
  272. _ = level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with peer IPnet", "IP", ip.String(), "IP2", i.String(), "peer", p.Name)
  273. continue CheckIPs
  274. }
  275. }
  276. }
  277. ret = append(ret, ip)
  278. }
  279. return
  280. }
  281. func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key wgtypes.Key, persistentKeepalive *time.Duration) *wireguard.Endpoint {
  282. // Do not update non-nat peers
  283. if persistentKeepalive == nil || *persistentKeepalive == time.Duration(0) {
  284. return endpoint
  285. }
  286. e, ok := t.discoveredEndpoints[key.String()]
  287. if ok {
  288. return wireguard.NewEndpointFromUDPAddr(e)
  289. }
  290. return endpoint
  291. }
  292. // Conf generates a WireGuard configuration file for a given Topology.
  293. func (t *Topology) Conf() *wireguard.Conf {
  294. c := &wireguard.Conf{
  295. Config: wgtypes.Config{
  296. PrivateKey: &t.key,
  297. ListenPort: &t.port,
  298. ReplacePeers: true,
  299. },
  300. }
  301. for _, s := range t.segments {
  302. if s.location == t.location {
  303. continue
  304. }
  305. peer := wireguard.Peer{
  306. PeerConfig: wgtypes.PeerConfig{
  307. AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
  308. PersistentKeepaliveInterval: &t.persistentKeepalive,
  309. PublicKey: s.key,
  310. ReplaceAllowedIPs: true,
  311. },
  312. Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
  313. }
  314. c.Peers = append(c.Peers, peer)
  315. }
  316. for _, p := range t.peers {
  317. peer := wireguard.Peer{
  318. PeerConfig: wgtypes.PeerConfig{
  319. AllowedIPs: p.AllowedIPs,
  320. PersistentKeepaliveInterval: &t.persistentKeepalive,
  321. PresharedKey: p.PresharedKey,
  322. PublicKey: p.PublicKey,
  323. ReplaceAllowedIPs: true,
  324. },
  325. Endpoint: t.updateEndpoint(p.Endpoint, p.PublicKey, p.PersistentKeepaliveInterval),
  326. }
  327. c.Peers = append(c.Peers, peer)
  328. }
  329. return c
  330. }
  331. // AsPeer generates the WireGuard peer configuration for the local location of the given Topology.
  332. // This configuration can be used to configure this location as a peer of another WireGuard interface.
  333. func (t *Topology) AsPeer() *wireguard.Peer {
  334. for _, s := range t.segments {
  335. if s.location != t.location {
  336. continue
  337. }
  338. p := &wireguard.Peer{
  339. PeerConfig: wgtypes.PeerConfig{
  340. AllowedIPs: s.allowedIPs,
  341. PublicKey: s.key,
  342. },
  343. Endpoint: s.endpoint,
  344. }
  345. return p
  346. }
  347. return nil
  348. }
  349. // PeerConf generates a WireGuard configuration file for a given peer in a Topology.
  350. func (t *Topology) PeerConf(name string) *wireguard.Conf {
  351. var pka *time.Duration
  352. var psk *wgtypes.Key
  353. for i := range t.peers {
  354. if t.peers[i].Name == name {
  355. pka = t.peers[i].PersistentKeepaliveInterval
  356. psk = t.peers[i].PresharedKey
  357. break
  358. }
  359. }
  360. c := &wireguard.Conf{}
  361. for _, s := range t.segments {
  362. peer := wireguard.Peer{
  363. PeerConfig: wgtypes.PeerConfig{
  364. AllowedIPs: append(s.allowedIPs, s.allowedLocationIPs...),
  365. PersistentKeepaliveInterval: pka,
  366. PresharedKey: psk,
  367. PublicKey: s.key,
  368. },
  369. Endpoint: t.updateEndpoint(s.endpoint, s.key, &s.persistentKeepalive),
  370. }
  371. c.Peers = append(c.Peers, peer)
  372. }
  373. for i := range t.peers {
  374. if t.peers[i].Name == name {
  375. continue
  376. }
  377. peer := wireguard.Peer{
  378. PeerConfig: wgtypes.PeerConfig{
  379. AllowedIPs: t.peers[i].AllowedIPs,
  380. PersistentKeepaliveInterval: pka,
  381. PublicKey: t.peers[i].PublicKey,
  382. },
  383. Endpoint: t.updateEndpoint(t.peers[i].Endpoint, t.peers[i].PublicKey, t.peers[i].PersistentKeepaliveInterval),
  384. }
  385. c.Peers = append(c.Peers, peer)
  386. }
  387. return c
  388. }
  389. // oneAddressCIDR takes an IP address and returns a CIDR
  390. // that contains only that address.
  391. func oneAddressCIDR(ip net.IP) *net.IPNet {
  392. return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
  393. }
  394. // findLeader selects a leader for the nodes in a segment;
  395. // it will select the first node that says it should lead
  396. // or the first node in the segment if none have volunteered,
  397. // always preferring those with a public external IP address,
  398. func findLeader(nodes []*Node) int {
  399. var leaders, public []int
  400. for i := range nodes {
  401. if nodes[i].Leader {
  402. if isPublic(nodes[i].Endpoint.IP()) {
  403. return i
  404. }
  405. leaders = append(leaders, i)
  406. }
  407. if nodes[i].Endpoint.IP() != nil && isPublic(nodes[i].Endpoint.IP()) {
  408. public = append(public, i)
  409. }
  410. }
  411. if len(leaders) != 0 {
  412. return leaders[0]
  413. }
  414. if len(public) != 0 {
  415. return public[0]
  416. }
  417. return 0
  418. }
  419. func deduplicatePeerIPs(peers []*Peer) []*Peer {
  420. ps := make([]*Peer, len(peers))
  421. ips := make(map[string]struct{})
  422. for i, peer := range peers {
  423. p := Peer{
  424. Name: peer.Name,
  425. Peer: wireguard.Peer{
  426. PeerConfig: wgtypes.PeerConfig{
  427. PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval,
  428. PresharedKey: peer.PresharedKey,
  429. PublicKey: peer.PublicKey,
  430. },
  431. Endpoint: peer.Endpoint,
  432. },
  433. }
  434. for _, ip := range peer.AllowedIPs {
  435. if _, ok := ips[ip.String()]; ok {
  436. continue
  437. }
  438. p.AllowedIPs = append(p.AllowedIPs, ip)
  439. ips[ip.String()] = struct{}{}
  440. }
  441. ps[i] = &p
  442. }
  443. return ps
  444. }