topology.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. // Copyright 2019 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mesh
  15. import (
  16. "bytes"
  17. "errors"
  18. "fmt"
  19. "net"
  20. "sort"
  21. "strings"
  22. "text/template"
  23. "github.com/vishvananda/netlink"
  24. "golang.org/x/sys/unix"
  25. )
var (
	// confTemplate renders the WireGuard configuration for a Topology.
	// It emits one [Interface] section from the Topology's Key and Port
	// followed by one [Peer] section for every segment whose Location
	// differs from the Topology's own Location. The Topology's Port is
	// reused as the port of every peer endpoint.
	confTemplate = template.Must(template.New("").Parse(`[Interface]
PrivateKey = {{.Key}}
ListenPort = {{.Port}}
{{range .Segments -}}
{{if ne .Location $.Location}}
[Peer]
PublicKey = {{.Key}}
Endpoint = {{.Endpoint}}:{{$.Port}}
AllowedIPs = {{.AllowedIPs}}
{{end}}
{{- end -}}
`))
)
// Topology represents the logical structure of the overlay network.
type Topology struct {
	// Some fields need to be exported so that the template can read them.

	// Key is the key of the local host with surrounding whitespace
	// trimmed; the configuration template renders it as PrivateKey.
	Key string
	// Port is rendered by the configuration template both as the local
	// ListenPort and as the port of every peer endpoint.
	Port int
	// Location is the logical location of the local host.
	Location string
	// Segments holds one entry per location in the topology;
	// NewTopology sorts it by Location so that the result is stable.
	Segments []*segment

	// hostname is the hostname of the local host.
	hostname string
	// leader represents whether or not the local host
	// is the segment leader.
	leader bool
	// subnet is the entire subnet from which IPs
	// for the WireGuard interfaces will be allocated.
	subnet *net.IPNet
	// privateIP is the private IP address of the local node.
	privateIP *net.IPNet
	// wireGuardCIDR is the allocated CIDR of the WireGuard
	// interface of the local node. If the local node is not
	// the leader, then it is nil.
	wireGuardCIDR *net.IPNet
}
// segment describes the group of nodes that share one location
// in the Topology.
type segment struct {
	// Some fields need to be exported so that the template can read them.

	// AllowedIPs is a comma-separated list holding, for every node in the
	// segment, its subnet and a one-address CIDR for its internal IP,
	// followed by the WireGuard address allocated to the segment.
	AllowedIPs string
	// Endpoint is the string form of the external IP of the segment leader.
	Endpoint string
	// Key is the key of the segment leader with surrounding whitespace
	// trimmed; the configuration template renders it as PublicKey.
	Key string
	// Location is the logical location of this segment.
	Location string

	// cidrs is a slice of subnets of all peers in the segment.
	cidrs []*net.IPNet
	// hostnames is a slice of the hostnames of the peers in the segment.
	hostnames []string
	// leader is the index of the leader of the segment.
	leader int
	// privateIPs is a slice of private IPs of all peers in the segment.
	privateIPs []net.IP
	// wireGuardIP is the allocated IP address of the WireGuard
	// interface on the leader of the segment.
	wireGuardIP net.IP
}
  82. // NewTopology creates a new Topology struct from a given set of nodes.
  83. func NewTopology(nodes map[string]*Node, granularity Granularity, hostname string, port int, key []byte, subnet *net.IPNet) (*Topology, error) {
  84. topoMap := make(map[string][]*Node)
  85. for _, node := range nodes {
  86. var location string
  87. switch granularity {
  88. case DataCenterGranularity:
  89. location = node.Location
  90. case NodeGranularity:
  91. location = node.Name
  92. }
  93. topoMap[location] = append(topoMap[location], node)
  94. }
  95. var localLocation string
  96. switch granularity {
  97. case DataCenterGranularity:
  98. localLocation = nodes[hostname].Location
  99. case NodeGranularity:
  100. localLocation = hostname
  101. }
  102. t := Topology{Key: strings.TrimSpace(string(key)), Port: port, hostname: hostname, Location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
  103. for location := range topoMap {
  104. // Sort the location so the result is stable.
  105. sort.Slice(topoMap[location], func(i, j int) bool {
  106. return topoMap[location][i].Name < topoMap[location][j].Name
  107. })
  108. leader := findLeader(topoMap[location])
  109. if location == localLocation && topoMap[location][leader].Name == hostname {
  110. t.leader = true
  111. }
  112. var allowedIPs []string
  113. var cidrs []*net.IPNet
  114. var hostnames []string
  115. var privateIPs []net.IP
  116. for _, node := range topoMap[location] {
  117. // Allowed IPs should include:
  118. // - the node's allocated subnet
  119. // - the node's WireGuard IP
  120. // - the node's internal IP
  121. allowedIPs = append(allowedIPs, node.Subnet.String(), oneAddressCIDR(node.InternalIP.IP).String())
  122. cidrs = append(cidrs, node.Subnet)
  123. hostnames = append(hostnames, node.Name)
  124. privateIPs = append(privateIPs, node.InternalIP.IP)
  125. }
  126. t.Segments = append(t.Segments, &segment{
  127. AllowedIPs: strings.Join(allowedIPs, ", "),
  128. Endpoint: topoMap[location][leader].ExternalIP.IP.String(),
  129. Key: strings.TrimSpace(string(topoMap[location][leader].Key)),
  130. Location: location,
  131. cidrs: cidrs,
  132. hostnames: hostnames,
  133. leader: leader,
  134. privateIPs: privateIPs,
  135. })
  136. }
  137. // Sort the Topology so the result is stable.
  138. sort.Slice(t.Segments, func(i, j int) bool {
  139. return t.Segments[i].Location < t.Segments[j].Location
  140. })
  141. // Allocate IPs to the segment leaders in a stable, coordination-free manner.
  142. a := newAllocator(*subnet)
  143. for _, segment := range t.Segments {
  144. ipNet := a.next()
  145. if ipNet == nil {
  146. return nil, errors.New("failed to allocate an IP address; ran out of IP addresses")
  147. }
  148. segment.wireGuardIP = ipNet.IP
  149. segment.AllowedIPs = fmt.Sprintf("%s, %s", segment.AllowedIPs, ipNet.String())
  150. if t.leader && segment.Location == t.Location {
  151. t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: t.subnet.Mask}
  152. }
  153. }
  154. return &t, nil
  155. }
  156. // RemoteSubnets identifies the subnets of the hosts in segments different than the host's.
  157. func (t *Topology) RemoteSubnets() []*net.IPNet {
  158. var remote []*net.IPNet
  159. for _, s := range t.Segments {
  160. if s == nil || s.Location == t.Location {
  161. continue
  162. }
  163. remote = append(remote, s.cidrs...)
  164. }
  165. return remote
  166. }
  167. // Routes generates a slice of routes for a given Topology.
  168. func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encapsulate Encapsulate) []*netlink.Route {
  169. var routes []*netlink.Route
  170. if !t.leader {
  171. // Find the leader for this segment.
  172. var leader net.IP
  173. for _, segment := range t.Segments {
  174. if segment.Location == t.Location {
  175. leader = segment.privateIPs[segment.leader]
  176. break
  177. }
  178. }
  179. for _, segment := range t.Segments {
  180. // First, add a route to the WireGuard IP of the segment.
  181. routes = append(routes, encapsulateRoute(&netlink.Route{
  182. Dst: oneAddressCIDR(segment.wireGuardIP),
  183. Flags: int(netlink.FLAG_ONLINK),
  184. Gw: leader,
  185. LinkIndex: privIface,
  186. Protocol: unix.RTPROT_STATIC,
  187. }, encapsulate, t.privateIP, tunlIface))
  188. // Add routes for the current segment if local is true.
  189. if segment.Location == t.Location {
  190. if local {
  191. for i := range segment.cidrs {
  192. // Don't add routes for the local node.
  193. if segment.privateIPs[i].Equal(t.privateIP.IP) {
  194. continue
  195. }
  196. routes = append(routes, encapsulateRoute(&netlink.Route{
  197. Dst: segment.cidrs[i],
  198. Flags: int(netlink.FLAG_ONLINK),
  199. Gw: segment.privateIPs[i],
  200. LinkIndex: privIface,
  201. Protocol: unix.RTPROT_STATIC,
  202. }, encapsulate, t.privateIP, tunlIface))
  203. }
  204. }
  205. continue
  206. }
  207. for i := range segment.cidrs {
  208. // Add routes to the Pod CIDRs of nodes in other segments.
  209. routes = append(routes, encapsulateRoute(&netlink.Route{
  210. Dst: segment.cidrs[i],
  211. Flags: int(netlink.FLAG_ONLINK),
  212. Gw: leader,
  213. LinkIndex: privIface,
  214. Protocol: unix.RTPROT_STATIC,
  215. }, encapsulate, t.privateIP, tunlIface))
  216. // Add routes to the private IPs of nodes in other segments.
  217. // Number of CIDRs and private IPs always match so
  218. // we can reuse the loop.
  219. routes = append(routes, encapsulateRoute(&netlink.Route{
  220. Dst: oneAddressCIDR(segment.privateIPs[i]),
  221. Flags: int(netlink.FLAG_ONLINK),
  222. Gw: leader,
  223. LinkIndex: privIface,
  224. Protocol: unix.RTPROT_STATIC,
  225. }, encapsulate, t.privateIP, tunlIface))
  226. }
  227. }
  228. return routes
  229. }
  230. for _, segment := range t.Segments {
  231. // Add routes for the current segment if local is true.
  232. if segment.Location == t.Location {
  233. if local {
  234. for i := range segment.cidrs {
  235. // Don't add routes for the local node.
  236. if segment.privateIPs[i].Equal(t.privateIP.IP) {
  237. continue
  238. }
  239. routes = append(routes, encapsulateRoute(&netlink.Route{
  240. Dst: segment.cidrs[i],
  241. Flags: int(netlink.FLAG_ONLINK),
  242. Gw: segment.privateIPs[i],
  243. LinkIndex: privIface,
  244. Protocol: unix.RTPROT_STATIC,
  245. }, encapsulate, t.privateIP, tunlIface))
  246. }
  247. }
  248. continue
  249. }
  250. for i := range segment.cidrs {
  251. // Add routes to the Pod CIDRs of nodes in other segments.
  252. routes = append(routes, &netlink.Route{
  253. Dst: segment.cidrs[i],
  254. Flags: int(netlink.FLAG_ONLINK),
  255. Gw: segment.wireGuardIP,
  256. LinkIndex: kiloIface,
  257. Protocol: unix.RTPROT_STATIC,
  258. })
  259. // Add routes to the private IPs of nodes in other segments.
  260. // Number of CIDRs and private IPs always match so
  261. // we can reuse the loop.
  262. routes = append(routes, &netlink.Route{
  263. Dst: oneAddressCIDR(segment.privateIPs[i]),
  264. Flags: int(netlink.FLAG_ONLINK),
  265. Gw: segment.wireGuardIP,
  266. LinkIndex: kiloIface,
  267. Protocol: unix.RTPROT_STATIC,
  268. })
  269. }
  270. }
  271. return routes
  272. }
  273. func encapsulateRoute(route *netlink.Route, encapsulate Encapsulate, subnet *net.IPNet, tunlIface int) *netlink.Route {
  274. if encapsulate == AlwaysEncapsulate || (encapsulate == CrossSubnetEncapsulate && !subnet.Contains(route.Gw)) {
  275. route.LinkIndex = tunlIface
  276. }
  277. return route
  278. }
  279. // Conf generates a WireGuard configuration file for a given Topology.
  280. func (t *Topology) Conf() ([]byte, error) {
  281. conf := new(bytes.Buffer)
  282. if err := confTemplate.Execute(conf, t); err != nil {
  283. return nil, err
  284. }
  285. return conf.Bytes(), nil
  286. }
  287. // oneAddressCIDR takes an IP address and returns a CIDR
  288. // that contains only that address.
  289. func oneAddressCIDR(ip net.IP) *net.IPNet {
  290. return &net.IPNet{IP: ip, Mask: net.CIDRMask(len(ip)*8, len(ip)*8)}
  291. }
  292. // findLeader selects a leader for the nodes in a segment;
  293. // it will select the first node that says it should lead
  294. // or the first node in the segment if none have volunteered,
  295. // always preferring those with a public external IP address,
  296. func findLeader(nodes []*Node) int {
  297. var leaders, public []int
  298. for i := range nodes {
  299. if nodes[i].Leader {
  300. if isPublic(nodes[i].ExternalIP) {
  301. return i
  302. }
  303. leaders = append(leaders, i)
  304. }
  305. if isPublic(nodes[i].ExternalIP) {
  306. public = append(public, i)
  307. }
  308. }
  309. if len(leaders) != 0 {
  310. return leaders[0]
  311. }
  312. if len(public) != 0 {
  313. return public[0]
  314. }
  315. return 0
  316. }