Bläddra i källkod

pkg/: FEATURE: support allowed IPs outside a cluster

Users can specify IPs with the annotation "allowed-location-ips".
It makes no difference which node of a location is annotated.
The IP should be routable from the particular location, e.g. a printer in
the same LAN.
This way these IPs become routable from other location.

Signed-off-by: leonnicolas <leonloechner@gmx.de>

Co-authored-by: Lucas Servén Marín <lserven@gmail.com>
leonnicolas 5 år sedan
förälder
incheckning
31ffaa0e71
9 ändrade filer med 368 tillägg och 26 borttagningar
  1. 1 1
      cmd/kgctl/graph.go
  2. 2 2
      cmd/kgctl/showconf.go
  3. 11 0
      pkg/k8s/backend.go
  4. 1 0
      pkg/mesh/backend.go
  5. 15 2
      pkg/mesh/mesh.go
  6. 32 0
      pkg/mesh/routes.go
  7. 56 0
      pkg/mesh/routes_test.go
  8. 76 3
      pkg/mesh/topology.go
  9. 174 18
      pkg/mesh/topology_test.go

+ 1 - 1
cmd/kgctl/graph.go

@@ -60,7 +60,7 @@ func runGraph(_ *cobra.Command, _ []string) error {
 			peers[p.Name] = p
 		}
 	}
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet, nodes[hostname].PersistentKeepalive)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %v", err)
 	}

+ 2 - 2
cmd/kgctl/showconf.go

@@ -147,7 +147,7 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
 		}
 	}
 
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, nodes[hostname].PersistentKeepalive)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, []byte{}, subnet, nodes[hostname].PersistentKeepalive, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %v", err)
 	}
@@ -236,7 +236,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
 		return fmt.Errorf("did not find any peer named %q in the cluster", peer)
 	}
 
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet, peers[peer].PersistentKeepalive)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, []byte{}, subnet, peers[peer].PersistentKeepalive, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %v", err)
 	}

+ 11 - 0
pkg/k8s/backend.go

@@ -59,6 +59,7 @@ const (
 	persistentKeepaliveKey       = "kilo.squat.ai/persistent-keepalive"
 	wireGuardIPAnnotationKey     = "kilo.squat.ai/wireguard-ip"
 	discoveredEndpointsKey       = "kilo.squat.ai/discovered-endpoints"
+	allowedLocationIPsKey        = "kilo.squat.ai/allowed-location-ips"
 	// RegionLabelKey is the key for the well-known Kubernetes topology region label.
 	RegionLabelKey  = "topology.kubernetes.io/region"
 	jsonPatchSlash  = "~1"
@@ -311,6 +312,15 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
 			discoveredEndpoints = nil
 		}
 	}
+	// Set allowed IPs for a location.
+	var allowedLocationIPs []*net.IPNet
+	if str, ok := node.ObjectMeta.Annotations[allowedLocationIPsKey]; ok {
+		for _, ip := range strings.Split(str, ",") {
+			if ipnet := normalizeIP(ip); ipnet != nil {
+				allowedLocationIPs = append(allowedLocationIPs, ipnet)
+			}
+		}
+	}
 
 	return &mesh.Node{
 		// Endpoint and InternalIP should only ever fail to parse if the
@@ -334,6 +344,7 @@ func translateNode(node *v1.Node, topologyLabel string) *mesh.Node {
 		// will parse as nil.
 		WireGuardIP:         normalizeIP(node.ObjectMeta.Annotations[wireGuardIPAnnotationKey]),
 		DiscoveredEndpoints: discoveredEndpoints,
+		AllowedLocationIPs:  allowedLocationIPs,
 	}
 }
 

+ 1 - 0
pkg/mesh/backend.go

@@ -67,6 +67,7 @@ type Node struct {
 	Subnet              *net.IPNet
 	WireGuardIP         *net.IPNet
 	DiscoveredEndpoints map[string]*wireguard.Endpoint
+	AllowedLocationIPs  []*net.IPNet
 }
 
 // Ready indicates whether or not the node is ready.

+ 15 - 2
pkg/mesh/mesh.go

@@ -380,6 +380,7 @@ func (m *Mesh) handleLocal(n *Node) {
 		Subnet:              n.Subnet,
 		WireGuardIP:         m.wireGuardIP,
 		DiscoveredEndpoints: n.DiscoveredEndpoints,
+		AllowedLocationIPs:  n.AllowedLocationIPs,
 	}
 	if !nodesAreEqual(n, local) {
 		level.Debug(m.logger).Log("msg", "local node differs from backend")
@@ -460,7 +461,7 @@ func (m *Mesh) applyTopology() {
 	oldConf := wireguard.Parse(oldConfRaw)
 	natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
 	nodes[m.hostname].DiscoveredEndpoints = natEndpoints
-	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive)
+	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
 	if err != nil {
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()
@@ -674,7 +675,7 @@ func nodesAreEqual(a, b *Node) bool {
 	// Ignore LastSeen when comparing equality we want to check if the nodes are
 	// equivalent. However, we do want to check if LastSeen has transitioned
 	// between valid and invalid.
-	return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints)
+	return string(a.Key) == string(b.Key) && ipNetsEqual(a.WireGuardIP, b.WireGuardIP) && ipNetsEqual(a.InternalIP, b.InternalIP) && a.Leader == b.Leader && a.Location == b.Location && a.Name == b.Name && subnetsEqual(a.Subnet, b.Subnet) && a.Ready() == b.Ready() && a.PersistentKeepalive == b.PersistentKeepalive && discoveredEndpointsAreEqual(a.DiscoveredEndpoints, b.DiscoveredEndpoints) && ipNetSlicesEqual(a.AllowedLocationIPs, b.AllowedLocationIPs)
 }
 
 func peersAreEqual(a, b *Peer) bool {
@@ -713,6 +714,18 @@ func ipNetsEqual(a, b *net.IPNet) bool {
 	return a.IP.Equal(b.IP)
 }
 
+func ipNetSlicesEqual(a, b []*net.IPNet) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if !ipNetsEqual(a[i], b[i]) {
+			return false
+		}
+	}
+	return true
+}
+
 func subnetsEqual(a, b *net.IPNet) bool {
 	if a == nil && b == nil {
 		return true

+ 32 - 0
pkg/mesh/routes.go

@@ -108,6 +108,17 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
 					Protocol:  unix.RTPROT_STATIC,
 				}, enc.Strategy(), t.privateIP, tunlIface))
 			}
+			// For segments / locations other than the location of this instance of kg,
+			// we need to set routes for allowed location IPs over the leader in the current location.
+			for i := range segment.allowedLocationIPs {
+				routes = append(routes, encapsulateRoute(&netlink.Route{
+					Dst:       segment.allowedLocationIPs[i],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        gw,
+					LinkIndex: privIface,
+					Protocol:  unix.RTPROT_STATIC,
+				}, enc.Strategy(), t.privateIP, tunlIface))
+			}
 		}
 		// Add routes for the allowed IPs of peers.
 		for _, peer := range t.peers {
@@ -198,6 +209,17 @@ func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface
 				Protocol:  unix.RTPROT_STATIC,
 			})
 		}
+		// For segments / locations other than the location of this instance of kg,
+		// we need to set routes for allowed location IPs over the wg interface.
+		for i := range segment.allowedLocationIPs {
+			routes = append(routes, &netlink.Route{
+				Dst:       segment.allowedLocationIPs[i],
+				Flags:     int(netlink.FLAG_ONLINK),
+				Gw:        segment.wireGuardIP,
+				LinkIndex: kiloIface,
+				Protocol:  unix.RTPROT_STATIC,
+			})
+		}
 	}
 	// Add routes for the allowed IPs of peers.
 	for _, peer := range t.peers {
@@ -232,6 +254,16 @@ func (t *Topology) Rules(cni bool) []iptables.Rule {
 		for _, aip := range s.allowedIPs {
 			rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
 		}
+		// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
+		// Otherwise packets to these destinations will reach the destination, but never find their way back.
+		// We only want to NAT in locations of the corresponding allowed location IPs.
+		if t.location == s.location {
+			for _, alip := range s.allowedLocationIPs {
+				rules = append(rules,
+					iptables.NewRule(iptables.GetProtocol(len(alip.IP)), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
+				)
+			}
+		}
 	}
 	for _, p := range t.peers {
 		for _, aip := range p.AllowedIPs {

+ 56 - 0
pkg/mesh/routes_test.go

@@ -74,6 +74,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[2].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -258,6 +265,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["d"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       peers["a"].AllowedIPs[0],
 					LinkIndex: kiloIface,
@@ -294,6 +308,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[2].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -422,6 +443,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[3].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -480,6 +508,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       nodes["d"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -538,6 +573,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       nodes["d"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -875,6 +917,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       nodes["c"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -1005,6 +1054,13 @@ func TestRoutes(t *testing.T) {
 					LinkIndex: kiloIface,
 					Protocol:  unix.RTPROT_STATIC,
 				},
+				{
+					Dst:       nodes["b"].AllowedLocationIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
 				{
 					Dst:       nodes["d"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),

+ 76 - 3
pkg/mesh/topology.go

@@ -19,6 +19,9 @@ import (
 	"net"
 	"sort"
 
+	"github.com/go-kit/kit/log"
+	"github.com/go-kit/kit/log/level"
+
 	"github.com/squat/kilo/pkg/wireguard"
 )
 
@@ -57,6 +60,7 @@ type Topology struct {
 	wireGuardCIDR *net.IPNet
 	// discoveredEndpoints is the updated map of valid discovered Endpoints
 	discoveredEndpoints map[string]*wireguard.Endpoint
+	logger              log.Logger
 }
 
 type segment struct {
@@ -78,10 +82,17 @@ type segment struct {
 	// wireGuardIP is the allocated IP address of the WireGuard
 	// interface on the leader of the segment.
 	wireGuardIP net.IP
+	// allowedLocationIPs are not part of the cluster and are not peers.
+	// They are directly routable from nodes within the segment.
+	// A classic example is a printer that ought to be routable from other locations.
+	allowedLocationIPs []*net.IPNet
 }
 
 // NewTopology creates a new Topology struct from a given set of nodes and peers.
-func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) (*Topology, error) {
+func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int, logger log.Logger) (*Topology, error) {
+	if logger == nil {
+		logger = log.NewNopLogger()
+	}
 	topoMap := make(map[string][]*Node)
 	for _, node := range nodes {
 		var location string
@@ -109,7 +120,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		localLocation = nodeLocationPrefix + hostname
 	}
 
-	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint)}
+	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, persistentKeepalive: persistentKeepalive, privateIP: nodes[hostname].InternalIP, subnet: nodes[hostname].Subnet, wireGuardCIDR: subnet, discoveredEndpoints: make(map[string]*wireguard.Endpoint), logger: logger}
 	for location := range topoMap {
 		// Sort the location so the result is stable.
 		sort.Slice(topoMap[location], func(i, j int) bool {
@@ -120,6 +131,8 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 			t.leader = true
 		}
 		var allowedIPs []*net.IPNet
+		allowedLocationIPsMap := make(map[string]struct{})
+		var allowedLocationIPs []*net.IPNet
 		var cidrs []*net.IPNet
 		var hostnames []string
 		var privateIPs []net.IP
@@ -128,7 +141,14 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 			// - the node's allocated subnet
 			// - the node's WireGuard IP
 			// - the node's internal IP
+			// - IPs that were specified by the allowed-location-ips annotation
 			allowedIPs = append(allowedIPs, node.Subnet)
+			for _, ip := range node.AllowedLocationIPs {
+				if _, ok := allowedLocationIPsMap[ip.String()]; !ok {
+					allowedLocationIPs = append(allowedLocationIPs, ip)
+					allowedLocationIPsMap[ip.String()] = struct{}{}
+				}
+			}
 			if node.InternalIP != nil {
 				allowedIPs = append(allowedIPs, oneAddressCIDR(node.InternalIP.IP))
 				privateIPs = append(privateIPs, node.InternalIP.IP)
@@ -136,6 +156,10 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 			cidrs = append(cidrs, node.Subnet)
 			hostnames = append(hostnames, node.Name)
 		}
+		// The sorting has no function, but makes testing easier.
+		sort.Slice(allowedLocationIPs, func(i, j int) bool {
+			return allowedLocationIPs[i].String() < allowedLocationIPs[j].String()
+		})
 		t.segments = append(t.segments, &segment{
 			allowedIPs:          allowedIPs,
 			endpoint:            topoMap[location][leader].Endpoint,
@@ -146,6 +170,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 			hostnames:           hostnames,
 			leader:              leader,
 			privateIPs:          privateIPs,
+			allowedLocationIPs:  allowedLocationIPs,
 		})
 	}
 	// Sort the Topology segments so the result is stable.
@@ -189,11 +214,59 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 				}
 			}
 		}
+		// Check for intersecting IPs in allowed location IPs
+		segment.allowedLocationIPs = t.filterAllowedLocationIPs(segment.allowedLocationIPs, segment.location)
 	}
 
 	return &t, nil
 }
 
+func intersect(n1, n2 *net.IPNet) bool {
+	return n1.Contains(n2.IP) || n2.Contains(n1.IP)
+}
+
+func (t *Topology) filterAllowedLocationIPs(ips []*net.IPNet, location string) (ret []*net.IPNet) {
+CheckIPs:
+	for _, ip := range ips {
+		for _, s := range t.segments {
+			// Check if allowed location IPs are also allowed in other locations.
+			if location != s.location {
+				for _, i := range s.allowedLocationIPs {
+					if intersect(ip, i) {
+						level.Warn(t.logger).Log("msg", "overlapping allowed location IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
+						continue CheckIPs
+					}
+				}
+			}
+			// Check if allowed location IPs intersect with the allowed IPs.
+			for _, i := range s.allowedIPs {
+				if intersect(ip, i) {
+					level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with allowed IPnets", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
+					continue CheckIPs
+				}
+			}
+			// Check if allowed location IPs intersect with the private IPs of the segment.
+			for _, i := range s.privateIPs {
+				if ip.Contains(i) {
+					level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with privateIP", "IP", ip.String(), "IP2", i.String(), "segment-location", s.location)
+					continue CheckIPs
+				}
+			}
+		}
+		// Check if allowed location IPs intersect with allowed IPs of peers.
+		for _, p := range t.peers {
+			for _, i := range p.AllowedIPs {
+				if intersect(ip, i) {
+					level.Warn(t.logger).Log("msg", "overlapping allowed location IPnet with peer IPnet", "IP", ip.String(), "IP2", i.String(), "peer", p.Name)
+					continue CheckIPs
+				}
+			}
+		}
+		ret = append(ret, ip)
+	}
+	return
+}
+
 func (t *Topology) updateEndpoint(endpoint *wireguard.Endpoint, key []byte, persistentKeepalive int) *wireguard.Endpoint {
 	// Do not update non-nat peers
 	if persistentKeepalive == 0 {
@@ -219,7 +292,7 @@ func (t *Topology) Conf() *wireguard.Conf {
 			continue
 		}
 		peer := &wireguard.Peer{
-			AllowedIPs:          s.allowedIPs,
+			AllowedIPs:          append(s.allowedIPs, s.allowedLocationIPs...),
 			Endpoint:            t.updateEndpoint(s.endpoint, s.key, s.persistentKeepalive),
 			PersistentKeepalive: t.persistentKeepalive,
 			PublicKey:           s.key,

+ 174 - 18
pkg/mesh/topology_test.go

@@ -19,6 +19,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/go-kit/kit/log"
 	"github.com/kylelemons/godebug/pretty"
 
 	"github.com/squat/kilo/pkg/wireguard"
@@ -28,6 +29,15 @@ func allowedIPs(ips ...string) string {
 	return strings.Join(ips, ", ")
 }
 
+func mustParseCIDR(s string) (r *net.IPNet) {
+	if _, ip, err := net.ParseCIDR(s); err != nil {
+		panic("failed to parse CIDR")
+	} else {
+		r = ip
+	}
+	return
+}
+
 func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
 	key := []byte("private")
 	e1 := &net.IPNet{IP: net.ParseIP("10.1.0.1").To4(), Mask: net.CIDRMask(16, 32)}
@@ -36,6 +46,7 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
 	e4 := &net.IPNet{IP: net.ParseIP("10.1.0.4").To4(), Mask: net.CIDRMask(16, 32)}
 	i1 := &net.IPNet{IP: net.ParseIP("192.168.0.1").To4(), Mask: net.CIDRMask(32, 32)}
 	i2 := &net.IPNet{IP: net.ParseIP("192.168.0.2").To4(), Mask: net.CIDRMask(32, 32)}
+	i3 := &net.IPNet{IP: net.ParseIP("192.168.178.3").To4(), Mask: net.CIDRMask(32, 32)}
 	nodes := map[string]*Node{
 		"a": {
 			Name:                "a",
@@ -47,12 +58,13 @@ func setup(t *testing.T) (map[string]*Node, map[string]*Peer, []byte, uint32) {
 			PersistentKeepalive: 25,
 		},
 		"b": {
-			Name:       "b",
-			Endpoint:   &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
-			InternalIP: i1,
-			Location:   "2",
-			Subnet:     &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
-			Key:        []byte("key2"),
+			Name:               "b",
+			Endpoint:           &wireguard.Endpoint{DNSOrIP: wireguard.DNSOrIP{IP: e2.IP}, Port: DefaultKiloPort},
+			InternalIP:         i1,
+			Location:           "2",
+			Subnet:             &net.IPNet{IP: net.ParseIP("10.2.2.0"), Mask: net.CIDRMask(24, 32)},
+			Key:                []byte("key2"),
+			AllowedLocationIPs: []*net.IPNet{i3},
 		},
 		"c": {
 			Name:       "c",
@@ -146,6 +158,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b", "c"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -159,7 +172,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w3,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -195,6 +209,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b", "c"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -208,7 +223,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w3,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -244,6 +260,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b", "c"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP, nodes["c"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["d"].Subnet, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -257,7 +274,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w3,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -293,6 +311,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -317,7 +336,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w4,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -353,6 +373,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -377,7 +398,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w4,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -413,6 +435,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -437,7 +460,8 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w4,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 		{
@@ -473,6 +497,7 @@ func TestNewTopology(t *testing.T) {
 						hostnames:           []string{"b"},
 						privateIPs:          []net.IP{nodes["b"].InternalIP.IP},
 						wireGuardIP:         w2,
+						allowedLocationIPs:  nodes["b"].AllowedLocationIPs,
 					},
 					{
 						allowedIPs:          []*net.IPNet{nodes["c"].Subnet, nodes["c"].InternalIP, {IP: w3, Mask: net.CIDRMask(32, 32)}},
@@ -497,13 +522,14 @@ func TestNewTopology(t *testing.T) {
 						wireGuardIP:         w4,
 					},
 				},
-				peers: []*Peer{peers["a"], peers["b"]},
+				peers:  []*Peer{peers["a"], peers["b"]},
+				logger: log.NewNopLogger(),
 			},
 		},
 	} {
 		tc.result.key = key
 		tc.result.port = port
-		topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet, 0)
+		topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet, 0, nil)
 		if err != nil {
 			t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err)
 		}
@@ -514,7 +540,7 @@ func TestNewTopology(t *testing.T) {
 }
 
 func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port uint32, key []byte, subnet *net.IPNet, persistentKeepalive int) *Topology {
-	topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive)
+	topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil)
 	if err != nil {
 		t.Errorf("failed to generate Topology: %v", err)
 	}
@@ -538,7 +564,7 @@ ListenPort = 51820
 [Peer]
 PublicKey = key2
 Endpoint = 10.1.0.2:51820
-AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32
+AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.2.3.0/24, 192.168.0.2/32, 10.4.0.2/32, 192.168.178.3/32
 PersistentKeepalive = 25
 
 [Peer]
@@ -623,7 +649,7 @@ PersistentKeepalive = 25
 		[Peer]
 		PublicKey = key2
 		Endpoint = 10.1.0.2:51820
-		AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
+		AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
 		PersistentKeepalive = 25
 
 		[Peer]
@@ -697,7 +723,7 @@ PersistentKeepalive = 25
 		[Peer]
 		PublicKey = key2
 		Endpoint = 10.1.0.2:51820
-		AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32
+		AllowedIPs = 10.2.2.0/24, 192.168.0.1/32, 10.4.0.2/32, 192.168.178.3/32
 
 		[Peer]
 		PublicKey = key4
@@ -953,3 +979,133 @@ func TestDeduplicatePeerIPs(t *testing.T) {
 		}
 	}
 }
+
+func TestFilterAllowedIPs(t *testing.T) {
+	nodes, peers, key, port := setup(t)
+	topo := mustTopo(t, nodes, peers, LogicalGranularity, nodes["a"].Name, port, key, DefaultKiloSubnet, nodes["a"].PersistentKeepalive)
+	for _, tc := range []struct {
+		name               string
+		allowedLocationIPs map[int][]*net.IPNet
+		result             map[int][]*net.IPNet
+	}{
+		{
+			name: "nothing to filter",
+			allowedLocationIPs: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.4/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+				},
+			},
+			result: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.4/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+				},
+			},
+		},
+		{
+			name: "intersections between segments",
+			allowedLocationIPs: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.4/32"),
+					mustParseCIDR("192.168.178.8/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+					mustParseCIDR("192.168.178.4/32"),
+				},
+			},
+			result: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.8/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+					mustParseCIDR("192.168.178.4/32"),
+				},
+			},
+		},
+		{
+			name: "intersections with wireGuardCIDR",
+			allowedLocationIPs: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("10.4.0.1/32"),
+					mustParseCIDR("192.168.178.8/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+				},
+			},
+			result: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.8/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.6/32"),
+					mustParseCIDR("192.168.178.7/32"),
+				},
+			},
+		},
+		{
+			name: "intersections with more than one allowedLocationIPs",
+			allowedLocationIPs: map[int][]*net.IPNet{
+				0: {
+					mustParseCIDR("192.168.178.8/32"),
+				},
+				1: {
+					mustParseCIDR("192.168.178.5/32"),
+				},
+				2: {
+					mustParseCIDR("192.168.178.7/24"),
+				},
+			},
+			result: map[int][]*net.IPNet{
+				0: {},
+				1: {},
+				2: {
+					mustParseCIDR("192.168.178.7/24"),
+				},
+			},
+		},
+	} {
+		for k, v := range tc.allowedLocationIPs {
+			topo.segments[k].allowedLocationIPs = v
+		}
+		for k, v := range topo.segments {
+			f := topo.filterAllowedLocationIPs(v.allowedLocationIPs, v.location)
+			// Overwrite the allowedLocationIPs to mimic the actual usage of the filterAllowedLocationIPs function.
+			topo.segments[k].allowedLocationIPs = f
+			if !ipNetSlicesEqual(f, tc.result[k]) {
+				t.Errorf("test case %q:\n\texpected:\n\t%q\n\tgot:\n\t%q\n", tc.name, tc.result[k], f)
+			}
+		}
+
+	}
+}