topology.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. // Copyright 2019 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mesh
  15. import (
  16. "errors"
  17. "net"
  18. "sort"
  19. "github.com/squat/kilo/pkg/encapsulation"
  20. "github.com/squat/kilo/pkg/iptables"
  21. "github.com/squat/kilo/pkg/wireguard"
  22. "github.com/vishvananda/netlink"
  23. "golang.org/x/sys/unix"
  24. )
// kiloTableIndex is the index of the routing table that holds the routes
// used to encapsulate traffic headed to private IPs; the policy rules
// generated alongside those routes point at this table.
const kiloTableIndex = 1107
// Topology represents the logical structure of the overlay network
// as seen from the local node.
type Topology struct {
	// key is the private key of the node creating the topology.
	key []byte
	// port is the port used as the WireGuard listen port
	// in the configuration generated for the local node.
	port uint32
	// location is the logical location of the local host.
	location string
	// segments holds one segment per location, sorted by location.
	segments []*segment
	// peers holds the peers of the mesh, sorted by name and with
	// duplicate allowed IPs removed.
	peers []*Peer
	// hostname is the hostname of the local host.
	hostname string
	// leader represents whether or not the local host
	// is the segment leader.
	leader bool
	// persistentKeepalive is the interval in seconds of the emission
	// of keepalive packets by the local node to its peers.
	persistentKeepalive int
	// privateIP is the private IP address of the local node.
	privateIP *net.IPNet
	// subnet is the Pod subnet of the local node.
	subnet *net.IPNet
	// wireGuardCIDR is the allocated CIDR of the WireGuard
	// interface of the local node. If the local node is not
	// the leader, then it is nil.
	wireGuardCIDR *net.IPNet
}
// segment describes all of the nodes that share one location.
// The cidrs, hostnames, and privateIPs slices are index-aligned:
// entry i of each slice belongs to the same node.
type segment struct {
	// allowedIPs is the set of CIDRs reachable through this segment:
	// the subnet and private IP of every node in the segment plus
	// the WireGuard IP allocated to the segment.
	allowedIPs []*net.IPNet
	// endpoint is the endpoint of the leader of the segment.
	endpoint *wireguard.Endpoint
	// key is the key of the leader of the segment; it is used as the
	// public key when the segment is rendered as a WireGuard peer.
	key []byte
	// location is the logical location of this segment.
	location string
	// cidrs is a slice of subnets of all peers in the segment.
	cidrs []*net.IPNet
	// hostnames is a slice of the hostnames of the peers in the segment.
	hostnames []string
	// leader is the index of the leader of the segment.
	leader int
	// privateIPs is a slice of private IPs of all peers in the segment.
	privateIPs []net.IP
	// wireGuardIP is the allocated IP address of the WireGuard
	// interface on the leader of the segment.
	wireGuardIP net.IP
}
  70. // NewTopology creates a new Topology struct from a given set of nodes and peers.
  71. func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) (*Topology, error) {
  72. topoMap := make(map[string][]*Node)
  73. for _, node := range nodes {
  74. var location string
  75. switch granularity {
  76. case LogicalGranularity:
  77. location = node.Location
  78. case FullGranularity:
  79. location = node.Name
  80. }
  81. topoMap[location] = append(topoMap[location], node)
  82. }
  83. var localLocation string
  84. switch granularity {
  85. case LogicalGranularity:
  86. localLocation = nodes[hostname].Location
  87. case FullGranularity:
  88. localLocation = hostname
  89. }
  90. t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet}
  91. for location := range topoMap {
  92. // Sort the location so the result is stable.
  93. sort.Slice(topoMap[location], func(i, j int) bool {
  94. return topoMap[location][i].Name < topoMap[location][j].Name
  95. })
  96. leader := findLeader(topoMap[location])
  97. if location == localLocation && topoMap[location][leader].Name == hostname {
  98. t.leader = true
  99. }
  100. var allowedIPs []*net.IPNet
  101. var cidrs []*net.IPNet
  102. var hostnames []string
  103. var privateIPs []net.IP
  104. for _, node := range topoMap[location] {
  105. // Allowed IPs should include:
  106. // - the node's allocated subnet
  107. // - the node's WireGuard IP
  108. // - the node's internal IP
  109. allowedIPs = append(allowedIPs, node.Subnet, oneAddressCIDR(node.InternalIP.IP))
  110. cidrs = append(cidrs, node.Subnet)
  111. hostnames = append(hostnames, node.Name)
  112. privateIPs = append(privateIPs, node.InternalIP.IP)
  113. }
  114. t.segments = append(t.segments, &segment{
  115. allowedIPs: allowedIPs,
  116. endpoint: topoMap[location][leader].Endpoint,
  117. key: topoMap[location][leader].Key,
  118. location: location,
  119. cidrs: cidrs,
  120. hostnames: hostnames,
  121. leader: leader,
  122. privateIPs: privateIPs,
  123. })
  124. }
  125. // Sort the Topology segments so the result is stable.
  126. sort.Slice(t.segments, func(i, j int) bool {
  127. return t.segments[i].location < t.segments[j].location
  128. })
  129. for _, peer := range peers {
  130. t.peers = append(t.peers, peer)
  131. }
  132. // Sort the Topology peers so the result is stable.
  133. sort.Slice(t.peers, func(i, j int) bool {
  134. return t.peers[i].Name < t.peers[j].Name
  135. })
  136. // We need to defensively deduplicate peer allowed IPs. If two peers claim the same IP,
  137. // the WireGuard configuration could flap, causing the interface to churn.
  138. t.peers = deduplicatePeerIPs(t.peers)
  139. // Allocate IPs to the segment leaders in a stable, coordination-free manner.
  140. a := newAllocator(*subnet)
  141. for _, segment := range t.segments {
  142. ipNet := a.next()
  143. if ipNet == nil {
  144. return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
  145. }
  146. segment.wireGuardIP = ipNet.IP
  147. segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP))
  148. if t.leader && segment.location == t.location {
  149. t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
  150. }
  151. }
  152. return &t, nil
  153. }
// Routes generates a slice of routes and a slice of policy rules
// for a given Topology.
//
// kiloIfaceName and kiloIface are the name and index of the Kilo
// interface, privIface is the index of the private interface, and
// tunlIface is the index of the tunnel interface that is used instead
// of privIface whenever the strategy reported by enc calls for
// encapsulation. When local is true, routes to the other nodes of the
// local segment are generated as well.
//
// If the local node is not the leader of its segment, all traffic for
// other segments and for peers is routed through the gateway that
// enc.Gw returns for the leader of the local segment. If the local node
// is the leader, that traffic is routed over the Kilo interface instead.
func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface int, local bool, enc encapsulation.Encapsulator) ([]*netlink.Route, []*netlink.Rule) {
	var routes []*netlink.Route
	var rules []*netlink.Rule
	if !t.leader {
		// Find the GW for this segment.
		// This will be an IP of the leader.
		// In an IPIP encapsulated mesh it is the leader's private IP.
		var gw net.IP
		for _, segment := range t.segments {
			if segment.location == t.location {
				gw = enc.Gw(segment.endpoint.IP, segment.privateIPs[segment.leader], segment.cidrs[segment.leader])
				break
			}
		}
		for _, segment := range t.segments {
			// First, add a route to the WireGuard IP of the segment.
			routes = append(routes, encapsulateRoute(&netlink.Route{
				Dst:       oneAddressCIDR(segment.wireGuardIP),
				Flags:     int(netlink.FLAG_ONLINK),
				Gw:        gw,
				LinkIndex: privIface,
				Protocol:  unix.RTPROT_STATIC,
			}, enc.Strategy(), t.privateIP, tunlIface))
			// Add routes for the current segment if local is true.
			if segment.location == t.location {
				if local {
					for i := range segment.cidrs {
						// Don't add routes for the local node.
						if segment.privateIPs[i].Equal(t.privateIP.IP) {
							continue
						}
						routes = append(routes, encapsulateRoute(&netlink.Route{
							Dst:       segment.cidrs[i],
							Flags:     int(netlink.FLAG_ONLINK),
							Gw:        segment.privateIPs[i],
							LinkIndex: privIface,
							Protocol:  unix.RTPROT_STATIC,
						}, enc.Strategy(), t.privateIP, tunlIface))
						// Encapsulate packets from the host's Pod subnet headed
						// to private IPs.
						if enc.Strategy() == encapsulation.Always || (enc.Strategy() == encapsulation.CrossSubnet && !t.privateIP.Contains(segment.privateIPs[i])) {
							routes = append(routes, &netlink.Route{
								Dst:       oneAddressCIDR(segment.privateIPs[i]),
								Flags:     int(netlink.FLAG_ONLINK),
								Gw:        segment.privateIPs[i],
								LinkIndex: tunlIface,
								Protocol:  unix.RTPROT_STATIC,
								Table:     kiloTableIndex,
							})
							rules = append(rules, defaultRule(&netlink.Rule{
								Src:   t.subnet,
								Dst:   oneAddressCIDR(segment.privateIPs[i]),
								Table: kiloTableIndex,
							}))
						}
					}
				}
				// The local segment needs no routes through the gateway.
				continue
			}
			for i := range segment.cidrs {
				// Add routes to the Pod CIDRs of nodes in other segments.
				routes = append(routes, encapsulateRoute(&netlink.Route{
					Dst:       segment.cidrs[i],
					Flags:     int(netlink.FLAG_ONLINK),
					Gw:        gw,
					LinkIndex: privIface,
					Protocol:  unix.RTPROT_STATIC,
				}, enc.Strategy(), t.privateIP, tunlIface))
				// Add routes to the private IPs of nodes in other segments.
				// Number of CIDRs and private IPs always match so
				// we can reuse the loop.
				routes = append(routes, encapsulateRoute(&netlink.Route{
					Dst:       oneAddressCIDR(segment.privateIPs[i]),
					Flags:     int(netlink.FLAG_ONLINK),
					Gw:        gw,
					LinkIndex: privIface,
					Protocol:  unix.RTPROT_STATIC,
				}, enc.Strategy(), t.privateIP, tunlIface))
			}
		}
		// Add routes for the allowed IPs of peers.
		for _, peer := range t.peers {
			for i := range peer.AllowedIPs {
				routes = append(routes, encapsulateRoute(&netlink.Route{
					Dst:       peer.AllowedIPs[i],
					Flags:     int(netlink.FLAG_ONLINK),
					Gw:        gw,
					LinkIndex: privIface,
					Protocol:  unix.RTPROT_STATIC,
				}, enc.Strategy(), t.privateIP, tunlIface))
			}
		}
		return routes, rules
	}
	// From here on the local node is the leader of its segment.
	for _, segment := range t.segments {
		// Add routes for the current segment if local is true.
		if segment.location == t.location {
			if local {
				for i := range segment.cidrs {
					// Don't add routes for the local node.
					if segment.privateIPs[i].Equal(t.privateIP.IP) {
						continue
					}
					routes = append(routes, encapsulateRoute(&netlink.Route{
						Dst:       segment.cidrs[i],
						Flags:     int(netlink.FLAG_ONLINK),
						Gw:        segment.privateIPs[i],
						LinkIndex: privIface,
						Protocol:  unix.RTPROT_STATIC,
					}, enc.Strategy(), t.privateIP, tunlIface))
					// Encapsulate packets from the host's Pod subnet headed
					// to private IPs.
					if enc.Strategy() == encapsulation.Always || (enc.Strategy() == encapsulation.CrossSubnet && !t.privateIP.Contains(segment.privateIPs[i])) {
						routes = append(routes, &netlink.Route{
							Dst:       oneAddressCIDR(segment.privateIPs[i]),
							Flags:     int(netlink.FLAG_ONLINK),
							Gw:        segment.privateIPs[i],
							LinkIndex: tunlIface,
							Protocol:  unix.RTPROT_STATIC,
							Table:     kiloTableIndex,
						})
						rules = append(rules, defaultRule(&netlink.Rule{
							Src:   t.subnet,
							Dst:   oneAddressCIDR(segment.privateIPs[i]),
							Table: kiloTableIndex,
						}))
						// Also encapsulate packets from the Kilo interface
						// headed to private IPs.
						rules = append(rules, defaultRule(&netlink.Rule{
							Dst:     oneAddressCIDR(segment.privateIPs[i]),
							Table:   kiloTableIndex,
							IifName: kiloIfaceName,
						}))
					}
				}
			}
			// The local segment needs no routes over the Kilo interface.
			continue
		}
		for i := range segment.cidrs {
			// Add routes to the Pod CIDRs of nodes in other segments.
			routes = append(routes, &netlink.Route{
				Dst:       segment.cidrs[i],
				Flags:     int(netlink.FLAG_ONLINK),
				Gw:        segment.wireGuardIP,
				LinkIndex: kiloIface,
				Protocol:  unix.RTPROT_STATIC,
			})
			// Don't add routes through Kilo if the private IP
			// equals the external IP. This means that the node
			// is only accessible through an external IP and we
			// cannot encapsulate traffic to an IP through the IP.
			if segment.privateIPs[i].Equal(segment.endpoint.IP) {
				continue
			}
			// Add routes to the private IPs of nodes in other segments.
			// Number of CIDRs and private IPs always match so
			// we can reuse the loop.
			routes = append(routes, &netlink.Route{
				Dst:       oneAddressCIDR(segment.privateIPs[i]),
				Flags:     int(netlink.FLAG_ONLINK),
				Gw:        segment.wireGuardIP,
				LinkIndex: kiloIface,
				Protocol:  unix.RTPROT_STATIC,
			})
		}
	}
	// Add routes for the allowed IPs of peers.
	for _, peer := range t.peers {
		for i := range peer.AllowedIPs {
			routes = append(routes, &netlink.Route{
				Dst:       peer.AllowedIPs[i],
				LinkIndex: kiloIface,
				Protocol:  unix.RTPROT_STATIC,
			})
		}
	}
	return routes, rules
}
  333. func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, subnet *net.IPNet, tunlIface int) *netlink.Route {
  334. if encapsulate == encapsulation.Always || (encapsulate == encapsulation.CrossSubnet && !subnet.Contains(route.Gw)) {
  335. route.LinkIndex = tunlIface
  336. }
  337. return route
  338. }
  339. // Conf generates a WireGuard configuration file for a given Topology.
  340. func (t *Topology) Conf() *wireguard.Conf {
  341. c := &wireguard.Conf{
  342. Interface: &wireguard.Interface{
  343. PrivateKey: t.key,
  344. ListenPort: t.port,
  345. },
  346. }
  347. for _, s := range t.segments {
  348. if s.location == t.location {
  349. continue
  350. }
  351. peer := &wireguard.Peer{
  352. AllowedIPs: s.allowedIPs,
  353. Endpoint: s.endpoint,
  354. PersistentKeepalive: t.persistentKeepalive,
  355. PublicKey: s.key,
  356. }
  357. c.Peers = append(c.Peers, peer)
  358. }
  359. for _, p := range t.peers {
  360. peer := &wireguard.Peer{
  361. AllowedIPs: p.AllowedIPs,
  362. Endpoint: p.Endpoint,
  363. PersistentKeepalive: t.persistentKeepalive,
  364. PresharedKey: p.PresharedKey,
  365. PublicKey: p.PublicKey,
  366. }
  367. c.Peers = append(c.Peers, peer)
  368. }
  369. return c
  370. }
  371. // AsPeer generates the WireGuard peer configuration for the local location of the given Topology.
  372. // This configuration can be used to configure this location as a peer of another WireGuard interface.
  373. func (t *Topology) AsPeer() *wireguard.Peer {
  374. for _, s := range t.segments {
  375. if s.location != t.location {
  376. continue
  377. }
  378. return &wireguard.Peer{
  379. AllowedIPs: s.allowedIPs,
  380. Endpoint: s.endpoint,
  381. PublicKey: s.key,
  382. }
  383. }
  384. return nil
  385. }
  386. // PeerConf generates a WireGuard configuration file for a given peer in a Topology.
  387. func (t *Topology) PeerConf(name string) *wireguard.Conf {
  388. var pka int
  389. var psk []byte
  390. for i := range t.peers {
  391. if t.peers[i].Name == name {
  392. pka = t.peers[i].PersistentKeepalive
  393. psk = t.peers[i].PresharedKey
  394. break
  395. }
  396. }
  397. c := &wireguard.Conf{}
  398. for _, s := range t.segments {
  399. peer := &wireguard.Peer{
  400. AllowedIPs: s.allowedIPs,
  401. Endpoint: s.endpoint,
  402. PersistentKeepalive: pka,
  403. PresharedKey: psk,
  404. PublicKey: s.key,
  405. }
  406. c.Peers = append(c.Peers, peer)
  407. }
  408. for i := range t.peers {
  409. if t.peers[i].Name == name {
  410. continue
  411. }
  412. peer := &wireguard.Peer{
  413. AllowedIPs: t.peers[i].AllowedIPs,
  414. PersistentKeepalive: pka,
  415. PublicKey: t.peers[i].PublicKey,
  416. Endpoint: t.peers[i].Endpoint,
  417. }
  418. c.Peers = append(c.Peers, peer)
  419. }
  420. return c
  421. }
  422. // Rules returns the iptables rules required by the local node.
  423. func (t *Topology) Rules(cni bool) []iptables.Rule {
  424. var rules []iptables.Rule
  425. rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
  426. rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
  427. if cni {
  428. rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT"))
  429. }
  430. for _, s := range t.segments {
  431. rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN"))
  432. for _, aip := range s.allowedIPs {
  433. rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN"))
  434. }
  435. }
  436. for _, p := range t.peers {
  437. for _, aip := range p.AllowedIPs {
  438. rules = append(rules,
  439. iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"),
  440. iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"),
  441. )
  442. }
  443. }
  444. rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
  445. rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
  446. return rules
  447. }
  448. // oneAddressCIDR takes an IP address and returns a CIDR
  449. // that contains only that address.
  450. func oneAddressCIDR(ip net.IP) *net.IPNet {
  451. return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
  452. }
  453. // findLeader selects a leader for the nodes in a segment;
  454. // it will select the first node that says it should lead
  455. // or the first node in the segment if none have volunteered,
  456. // always preferring those with a public external IP address,
  457. func findLeader(nodes []*Node) int {
  458. var leaders, public []int
  459. for i := range nodes {
  460. if nodes[i].Leader {
  461. if isPublic(nodes[i].Endpoint.IP) {
  462. return i
  463. }
  464. leaders = append(leaders, i)
  465. }
  466. if isPublic(nodes[i].Endpoint.IP) {
  467. public = append(public, i)
  468. }
  469. }
  470. if len(leaders) != 0 {
  471. return leaders[0]
  472. }
  473. if len(public) != 0 {
  474. return public[0]
  475. }
  476. return 0
  477. }
  478. func deduplicatePeerIPs(peers []*Peer) []*Peer {
  479. ps := make([]*Peer, len(peers))
  480. ips := make(map[string]struct{})
  481. for i, peer := range peers {
  482. p := Peer{
  483. Name: peer.Name,
  484. Peer: wireguard.Peer{
  485. Endpoint: peer.Endpoint,
  486. PersistentKeepalive: peer.PersistentKeepalive,
  487. PresharedKey: peer.PresharedKey,
  488. PublicKey: peer.PublicKey,
  489. },
  490. }
  491. for _, ip := range peer.AllowedIPs {
  492. if _, ok := ips[ip.String()]; ok {
  493. continue
  494. }
  495. p.AllowedIPs = append(p.AllowedIPs, ip)
  496. ips[ip.String()] = struct{}{}
  497. }
  498. ps[i] = &p
  499. }
  500. return ps
  501. }
  502. func defaultRule(rule *netlink.Rule) *netlink.Rule {
  503. base := netlink.NewRule()
  504. base.Src = rule.Src
  505. base.Dst = rule.Dst
  506. base.IifName = rule.IifName
  507. base.Table = rule.Table
  508. return base
  509. }