Procházet zdrojové kódy

pkg/route,pkg/mesh: replace NAT with ip rules

This commit entirely replaces NAT in Kilo with a few iproute2 rules.
Previously, Kilo would source-NAT the majority of packets in order to
avoid problems with strict source checks in cloud providers causing
packets to be considered martians. This source-NAT-ing made it
difficult to correctly apply Kuberenetes NetworkPolicies based on source
IPs.

This rewrite instead relies on a handful of iproute2 rules to ensure
that packets get encapsulated in certain scenarios based on the source
network and/or source interface.

This has the benefit of avoiding extra iptables bloat as well as
enabling better compatibility with NetworkPolicies.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
Lucas Servén Marín před 6 roky
rodič
revize
134cbe90be
5 změnil soubory, kde provedl 605 přidání a 151 odebrání
  1. 12 12
      pkg/mesh/mesh.go
  2. 59 7
      pkg/mesh/topology.go
  3. 233 22
      pkg/mesh/topology_test.go
  4. 113 41
      pkg/route/route.go
  5. 188 69
      pkg/route/route_test.go

+ 12 - 12
pkg/mesh/mesh.go

@@ -604,14 +604,7 @@ func (m *Mesh) applyTopology() {
 		m.errorCounter.WithLabelValues("apply").Inc()
 		return
 	}
-	rules := iptables.ForwardRules(m.subnet)
-	// Finx the Kilo interface name.
-	link, err := linkByIndex(m.kiloIface)
-	if err != nil {
-		level.Error(m.logger).Log("error", err)
-		m.errorCounter.WithLabelValues("apply").Inc()
-		return
-	}
+	ipRules := iptables.ForwardRules(m.subnet)
 	// If we are handling local routes, ensure the local
 	// tunnel has an IP address and IPIP traffic is allowed.
 	if m.enc.Strategy() != encapsulation.Never && m.local {
@@ -624,7 +617,7 @@ func (m *Mesh) applyTopology() {
 				break
 			}
 		}
-		rules = append(rules, m.enc.Rules(cidrs)...)
+		ipRules = append(ipRules, m.enc.Rules(cidrs)...)
 
 		// If we are handling local routes, ensure the local
 		// tunnel has an IP address.
@@ -634,7 +627,14 @@ func (m *Mesh) applyTopology() {
 			return
 		}
 	}
-	if err := m.ipTables.Set(rules); err != nil {
+	if err := m.ipTables.Set(ipRules); err != nil {
+		level.Error(m.logger).Log("error", err)
+		m.errorCounter.WithLabelValues("apply").Inc()
+		return
+	}
+	// Find the Kilo interface name.
+	link, err := linkByIndex(m.kiloIface)
+	if err != nil {
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()
 		return
@@ -677,8 +677,8 @@ func (m *Mesh) applyTopology() {
 	}
 	// We need to add routes last since they may depend
 	// on the WireGuard interface.
-	routes := t.Routes(m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc)
-	if err := m.table.Set(routes); err != nil {
+	routes, rules := t.Routes(link.Attrs().Name, m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc)
+	if err := m.table.Set(routes, rules); err != nil {
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()
 	}

+ 59 - 7
pkg/mesh/topology.go

@@ -25,6 +25,8 @@ import (
 	"golang.org/x/sys/unix"
 )
 
+const kiloTableIndex = 1107
+
 // Topology represents the logical structure of the overlay network.
 type Topology struct {
 	// key is the private key of the node creating the topology.
@@ -40,8 +42,7 @@ type Topology struct {
 	// leader represents whether or not the local host
 	// is the segment leader.
 	leader bool
-	// subnet is the entire subnet from which IPs
-	// for the WireGuard interfaces will be allocated.
+	// subnet is the Pod subnet of the local node.
 	subnet *net.IPNet
 	// privateIP is the private IP address  of the local node.
 	privateIP *net.IPNet
@@ -95,7 +96,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		localLocation = hostname
 	}
 
-	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, subnet: subnet, privateIP: nodes[hostname].InternalIP}
+	t := Topology{key: key, port: port, hostname: hostname, location: localLocation, subnet: nodes[hostname].Subnet, privateIP: nodes[hostname].InternalIP}
 	for location := range topoMap {
 		// Sort the location so the result is stable.
 		sort.Slice(topoMap[location], func(i, j int) bool {
@@ -156,7 +157,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		segment.wireGuardIP = ipNet.IP
 		segment.allowedIPs = append(segment.allowedIPs, oneAddressCIDR(ipNet.IP))
 		if t.leader && segment.location == t.location {
-			t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: t.subnet.Mask}
+			t.wireGuardCIDR = &net.IPNet{IP: ipNet.IP, Mask: subnet.Mask}
 		}
 	}
 
@@ -164,8 +165,9 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 }
 
 // Routes generates a slice of routes for a given Topology.
-func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, enc encapsulation.Encapsulator) []*netlink.Route {
+func (t *Topology) Routes(kiloIfaceName string, kiloIface, privIface, tunlIface int, local bool, enc encapsulation.Encapsulator) ([]*netlink.Route, []*netlink.Rule) {
 	var routes []*netlink.Route
+	var rules []*netlink.Rule
 	if !t.leader {
 		// Find the GW for this segment.
 		// This will be the an IP of the leader.
@@ -201,6 +203,23 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, enc e
 							LinkIndex: privIface,
 							Protocol:  unix.RTPROT_STATIC,
 						}, enc.Strategy(), t.privateIP, tunlIface))
+						// Encapsulate packets from the host's Pod subnet headed
+						// to private IPs.
+						if enc.Strategy() == encapsulation.Always || (enc.Strategy() == encapsulation.CrossSubnet && !t.privateIP.Contains(segment.privateIPs[i])) {
+							routes = append(routes, &netlink.Route{
+								Dst:       oneAddressCIDR(segment.privateIPs[i]),
+								Flags:     int(netlink.FLAG_ONLINK),
+								Gw:        segment.privateIPs[i],
+								LinkIndex: tunlIface,
+								Protocol:  unix.RTPROT_STATIC,
+								Table:     kiloTableIndex,
+							})
+							rules = append(rules, defaultRule(&netlink.Rule{
+								Src:   t.subnet,
+								Dst:   oneAddressCIDR(segment.privateIPs[i]),
+								Table: kiloTableIndex,
+							}))
+						}
 					}
 				}
 				continue
@@ -238,7 +257,7 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, enc e
 				}, enc.Strategy(), t.privateIP, tunlIface))
 			}
 		}
-		return routes
+		return routes, rules
 	}
 	for _, segment := range t.segments {
 		// Add routes for the current segment if local is true.
@@ -256,6 +275,30 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, enc e
 						LinkIndex: privIface,
 						Protocol:  unix.RTPROT_STATIC,
 					}, enc.Strategy(), t.privateIP, tunlIface))
+					// Encapsulate packets from the host's Pod subnet headed
+					// to private IPs.
+					if enc.Strategy() == encapsulation.Always || (enc.Strategy() == encapsulation.CrossSubnet && !t.privateIP.Contains(segment.privateIPs[i])) {
+						routes = append(routes, &netlink.Route{
+							Dst:       oneAddressCIDR(segment.privateIPs[i]),
+							Flags:     int(netlink.FLAG_ONLINK),
+							Gw:        segment.privateIPs[i],
+							LinkIndex: tunlIface,
+							Protocol:  unix.RTPROT_STATIC,
+							Table:     kiloTableIndex,
+						})
+						rules = append(rules, defaultRule(&netlink.Rule{
+							Src:   t.subnet,
+							Dst:   oneAddressCIDR(segment.privateIPs[i]),
+							Table: kiloTableIndex,
+						}))
+						// Also encapsulate packets from the Kilo interface
+						// headed to private IPs.
+						rules = append(rules, defaultRule(&netlink.Rule{
+							Dst:     oneAddressCIDR(segment.privateIPs[i]),
+							Table:   kiloTableIndex,
+							IifName: kiloIfaceName,
+						}))
+					}
 				}
 			}
 			continue
@@ -298,7 +341,7 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, enc e
 			})
 		}
 	}
-	return routes
+	return routes, rules
 }
 
 func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, subnet *net.IPNet, tunlIface int) *netlink.Route {
@@ -448,3 +491,12 @@ func deduplicatePeerIPs(peers []*Peer) []*Peer {
 	}
 	return ps
 }
+
+func defaultRule(rule *netlink.Rule) *netlink.Rule {
+	base := netlink.NewRule()
+	base.Src = rule.Src
+	base.Dst = rule.Dst
+	base.IifName = rule.IifName
+	base.Table = rule.Table
+	return base
+}

+ 233 - 22
pkg/mesh/topology_test.go

@@ -113,7 +113,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["a"].Name,
 				leader:        true,
 				location:      nodes["a"].Location,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["a"].Subnet,
 				privateIP:     nodes["a"].InternalIP,
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
@@ -150,7 +150,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["b"].Name,
 				leader:        true,
 				location:      nodes["b"].Location,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["b"].Subnet,
 				privateIP:     nodes["b"].InternalIP,
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
@@ -187,7 +187,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["c"].Name,
 				leader:        false,
 				location:      nodes["b"].Location,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["c"].Subnet,
 				privateIP:     nodes["c"].InternalIP,
 				wireGuardCIDR: nil,
 				segments: []*segment{
@@ -224,7 +224,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["a"].Name,
 				leader:        true,
 				location:      nodes["a"].Name,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["a"].Subnet,
 				privateIP:     nodes["a"].InternalIP,
 				wireGuardCIDR: &net.IPNet{IP: w1, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
@@ -271,7 +271,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["b"].Name,
 				leader:        true,
 				location:      nodes["b"].Name,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["b"].Subnet,
 				privateIP:     nodes["b"].InternalIP,
 				wireGuardCIDR: &net.IPNet{IP: w2, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
@@ -318,7 +318,7 @@ func TestNewTopology(t *testing.T) {
 				hostname:      nodes["c"].Name,
 				leader:        true,
 				location:      nodes["c"].Name,
-				subnet:        DefaultKiloSubnet,
+				subnet:        nodes["c"].Subnet,
 				privateIP:     nodes["c"].InternalIP,
 				wireGuardCIDR: &net.IPNet{IP: w3, Mask: net.CIDRMask(16, 32)},
 				segments: []*segment{
@@ -382,7 +382,7 @@ func TestRoutes(t *testing.T) {
 	nodes, peers, key, port := setup(t)
 	kiloIface := 0
 	privIface := 1
-	pubIface := 2
+	tunlIface := 2
 	mustTopoForGranularityAndHost := func(granularity Granularity, hostname string) *Topology {
 		return mustTopo(t, nodes, peers, granularity, hostname, port, key, DefaultKiloSubnet)
 	}
@@ -391,12 +391,15 @@ func TestRoutes(t *testing.T) {
 		name     string
 		local    bool
 		topology *Topology
-		result   []*netlink.Route
+		strategy encapsulation.Strategy
+		routes   []*netlink.Route
+		rules    []*netlink.Rule
 	}{
 		{
 			name:     "logical from a",
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -445,7 +448,8 @@ func TestRoutes(t *testing.T) {
 		{
 			name:     "logical from b",
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name).segments[0].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -480,7 +484,8 @@ func TestRoutes(t *testing.T) {
 		{
 			name:     "logical from c",
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       oneAddressCIDR(mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name).segments[0].wireGuardIP),
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -535,7 +540,8 @@ func TestRoutes(t *testing.T) {
 		{
 			name:     "full from a",
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name).segments[1].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -584,7 +590,8 @@ func TestRoutes(t *testing.T) {
 		{
 			name:     "full from b",
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["b"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       mustTopoForGranularityAndHost(FullGranularity, nodes["b"].Name).segments[0].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -633,7 +640,8 @@ func TestRoutes(t *testing.T) {
 		{
 			name:     "full from c",
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name).segments[0].cidrs[0],
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -683,7 +691,59 @@ func TestRoutes(t *testing.T) {
 			name:     "logical from a local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
+				{
+					Dst:       nodes["b"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(nodes["b"].InternalIP.IP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       nodes["c"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(nodes["c"].InternalIP.IP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name).segments[1].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[0],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[1],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["b"].AllowedIPs[0],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+			},
+		},
+		{
+			name:     "logical from a local always",
+			local:    true,
+			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["a"].Name),
+			strategy: encapsulation.Always,
+			routes: []*netlink.Route{
 				{
 					Dst:       nodes["b"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -733,7 +793,8 @@ func TestRoutes(t *testing.T) {
 			name:     "logical from b local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       nodes["a"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -772,11 +833,76 @@ func TestRoutes(t *testing.T) {
 				},
 			},
 		},
+		{
+			name:     "logical from b local always",
+			local:    true,
+			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name),
+			strategy: encapsulation.Always,
+			routes: []*netlink.Route{
+				{
+					Dst:       nodes["a"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name).segments[0].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(nodes["a"].InternalIP.IP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        mustTopoForGranularityAndHost(LogicalGranularity, nodes["b"].Name).segments[0].wireGuardIP,
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       nodes["c"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["c"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(nodes["c"].InternalIP.IP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["c"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+					Table:     kiloTableIndex,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[0],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[1],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["b"].AllowedIPs[0],
+					LinkIndex: kiloIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+			},
+			rules: []*netlink.Rule{
+				defaultRule(&netlink.Rule{
+					Src:   nodes["b"].Subnet,
+					Dst:   nodes["c"].InternalIP,
+					Table: kiloTableIndex,
+				}),
+				defaultRule(&netlink.Rule{
+					Dst:     nodes["c"].InternalIP,
+					IifName: DefaultKiloInterface,
+					Table:   kiloTableIndex,
+				}),
+			},
+		},
 		{
 			name:     "logical from c local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       oneAddressCIDR(mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name).segments[0].wireGuardIP),
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -835,11 +961,91 @@ func TestRoutes(t *testing.T) {
 				},
 			},
 		},
+		{
+			name:     "logical from c local always",
+			local:    true,
+			topology: mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name),
+			strategy: encapsulation.Always,
+			routes: []*netlink.Route{
+				{
+					Dst:       oneAddressCIDR(mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name).segments[0].wireGuardIP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       nodes["a"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(nodes["a"].InternalIP.IP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       oneAddressCIDR(mustTopoForGranularityAndHost(LogicalGranularity, nodes["c"].Name).segments[1].wireGuardIP),
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       nodes["b"].Subnet,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       nodes["b"].InternalIP,
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+					Table:     kiloTableIndex,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["a"].AllowedIPs[1],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+				{
+					Dst:       peers["b"].AllowedIPs[0],
+					Flags:     int(netlink.FLAG_ONLINK),
+					Gw:        nodes["b"].InternalIP.IP,
+					LinkIndex: tunlIface,
+					Protocol:  unix.RTPROT_STATIC,
+				},
+			},
+			rules: []*netlink.Rule{
+				defaultRule(&netlink.Rule{
+					Src:   nodes["c"].Subnet,
+					Dst:   nodes["b"].InternalIP,
+					Table: kiloTableIndex,
+				}),
+			},
+		},
 		{
 			name:     "full from a local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["a"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       nodes["b"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -889,7 +1095,8 @@ func TestRoutes(t *testing.T) {
 			name:     "full from b local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["b"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       nodes["a"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -939,7 +1146,8 @@ func TestRoutes(t *testing.T) {
 			name:     "full from c local",
 			local:    true,
 			topology: mustTopoForGranularityAndHost(FullGranularity, nodes["c"].Name),
-			result: []*netlink.Route{
+			strategy: encapsulation.Never,
+			routes: []*netlink.Route{
 				{
 					Dst:       nodes["a"].Subnet,
 					Flags:     int(netlink.FLAG_ONLINK),
@@ -986,8 +1194,11 @@ func TestRoutes(t *testing.T) {
 			},
 		},
 	} {
-		routes := tc.topology.Routes(kiloIface, privIface, pubIface, tc.local, encapsulation.NewIPIP(encapsulation.Never))
-		if diff := pretty.Compare(routes, tc.result); diff != "" {
+		routes, rules := tc.topology.Routes(DefaultKiloInterface, kiloIface, privIface, tunlIface, tc.local, encapsulation.NewIPIP(tc.strategy))
+		if diff := pretty.Compare(routes, tc.routes); diff != "" {
+			t.Errorf("test case %q: got diff: %v", tc.name, diff)
+		}
+		if diff := pretty.Compare(rules, tc.rules); diff != "" {
 			t.Errorf("test case %q: got diff: %v", tc.name, diff)
 		}
 	}

+ 113 - 41
pkg/route/route.go

@@ -28,22 +28,24 @@ import (
 type Table struct {
 	errors     chan error
 	mu         sync.Mutex
-	routes     map[string]*netlink.Route
+	rs         map[string]interface{}
 	subscribed bool
 
 	// Make these functions fields to allow
 	// for testing.
-	add func(*netlink.Route) error
-	del func(*netlink.Route) error
+	addRoute func(*netlink.Route) error
+	delRoute func(*netlink.Route) error
+	addRule  func(*netlink.Rule) error
+	delRule  func(*netlink.Rule) error
 }
 
 // NewTable generates a new table.
 func NewTable() *Table {
 	return &Table{
-		errors: make(chan error),
-		routes: make(map[string]*netlink.Route),
-		add:    netlink.RouteReplace,
-		del: func(r *netlink.Route) error {
+		errors:   make(chan error),
+		rs:       make(map[string]interface{}),
+		addRoute: netlink.RouteReplace,
+		delRoute: func(r *netlink.Route) error {
 			name := routeToString(r)
 			if name == "" {
 				return errors.New("attempting to delete invalid route")
@@ -59,10 +61,27 @@ func NewTable() *Table {
 			}
 			return nil
 		},
+		addRule: netlink.RuleAdd,
+		delRule: func(r *netlink.Rule) error {
+			name := ruleToString(r)
+			if name == "" {
+				return errors.New("attempting to delete invalid rule")
+			}
+			rules, err := netlink.RuleList(netlink.FAMILY_ALL)
+			if err != nil {
+				return fmt.Errorf("failed to list rules before deletion: %v", err)
+			}
+			for _, rule := range rules {
+				if ruleToString(&rule) == name {
+					return netlink.RuleDel(r)
+				}
+			}
+			return nil
+		},
 	}
 }
 
-// Run watches for changes to routes in the table and reconciles
+// Run watches for changes to routes and rules in the table and reconciles
 // the table against the desired state.
 func (t *Table) Run(stop <-chan struct{}) (<-chan error, error) {
 	t.mu.Lock()
@@ -90,16 +109,19 @@ func (t *Table) Run(stop <-chan struct{}) (<-chan error, error) {
 			// Watch for deleted routes to reconcile this table's routes.
 			case unix.RTM_DELROUTE:
 				t.mu.Lock()
-				for _, r := range t.routes {
-					// Filter out invalid routes.
-					if r == nil || r.Dst == nil {
-						continue
-					}
-					// If any deleted route's destination matches a destination
-					// in the table, reset the corresponding route just in case.
-					if r.Dst.IP.Equal(e.Route.Dst.IP) && r.Dst.Mask.String() == e.Route.Dst.Mask.String() {
-						if err := t.add(r); err != nil {
-							nonBlockingSend(t.errors, fmt.Errorf("failed add route: %v", err))
+				for k := range t.rs {
+					switch r := t.rs[k].(type) {
+					case *netlink.Route:
+						// Filter out invalid routes.
+						if r == nil || r.Dst == nil {
+							continue
+						}
+						// If any deleted route's destination matches a destination
+						// in the table, reset the corresponding route just in case.
+						if r.Dst.IP.Equal(e.Route.Dst.IP) && r.Dst.Mask.String() == e.Route.Dst.Mask.String() {
+							if err := t.addRoute(r); err != nil {
+								nonBlockingSend(t.errors, fmt.Errorf("failed add route: %v", err))
+							}
 						}
 					}
 				}
@@ -110,46 +132,66 @@ func (t *Table) Run(stop <-chan struct{}) (<-chan error, error) {
 	return t.errors, nil
 }
 
-// CleanUp will clean up any routes created by the instance.
+// CleanUp will clean up any routes and rules created by the instance.
 func (t *Table) CleanUp() error {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	for k, route := range t.routes {
-		if err := t.del(route); err != nil {
-			return fmt.Errorf("failed to delete route: %v", err)
+	for k := range t.rs {
+		switch r := t.rs[k].(type) {
+		case *netlink.Route:
+			if err := t.delRoute(r); err != nil {
+				return fmt.Errorf("failed to delete route: %v", err)
+			}
+		case *netlink.Rule:
+			if err := t.delRule(r); err != nil {
+				return fmt.Errorf("failed to delete rule: %v", err)
+			}
 		}
-		delete(t.routes, k)
+		delete(t.rs, k)
 	}
 	return nil
 }
 
-// Set idempotently overwrites any routes previously defined
-// for the table with the given set of routes.
-func (t *Table) Set(routes []*netlink.Route) error {
-	r := make(map[string]*netlink.Route)
+// Set idempotently overwrites any routes and rules previously defined
+// for the table with the given set of routes and rules.
+func (t *Table) Set(routes []*netlink.Route, rules []*netlink.Rule) error {
+	rs := make(map[string]interface{})
 	for _, route := range routes {
 		if route == nil {
 			continue
 		}
-		r[routeToString(route)] = route
+		rs[routeToString(route)] = route
+	}
+	for _, rule := range rules {
+		if rule == nil {
+			continue
+		}
+		rs[ruleToString(rule)] = rule
 	}
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	for k := range t.routes {
-		if _, ok := r[k]; !ok {
-			if err := t.del(t.routes[k]); err != nil {
-				return fmt.Errorf("failed to delete route: %v", err)
+	for k := range t.rs {
+		if _, ok := rs[k]; !ok {
+			switch r := t.rs[k].(type) {
+			case *netlink.Route:
+				if err := t.delRoute(r); err != nil {
+					return fmt.Errorf("failed to delete route: %v", err)
+				}
+			case *netlink.Rule:
+				if err := t.delRule(r); err != nil {
+					return fmt.Errorf("failed to delete rule: %v", err)
+				}
 			}
-			delete(t.routes, k)
+			delete(t.rs, k)
 		}
 	}
 
-	// When adding routes, we need to compare against what is
+	// When adding routes/rules, we need to compare against what is
 	// actually on the Linux routing table. This is because
-	// routes can be deleted by the kernel due to interface churn
-	// causing a situation where the controller thinks it has a route
+	// routes/rules can be deleted by the kernel due to interface churn
+	// causing a situation where the controller thinks it has an item
 	// that is not actually there.
-	existing := make(map[string]*netlink.Route)
+	existing := make(map[string]interface{})
 	existingRoutes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
 	if err != nil {
 		return fmt.Errorf("failed to list existing routes: %v", err)
@@ -158,12 +200,27 @@ func (t *Table) Set(routes []*netlink.Route) error {
 		existing[routeToString(&existingRoutes[k])] = &existingRoutes[k]
 	}
 
-	for k := range r {
+	existingRules, err := netlink.RuleList(netlink.FAMILY_ALL)
+	if err != nil {
+		return fmt.Errorf("failed to list existing rules: %v", err)
+	}
+	for k := range existingRules {
+		existing[ruleToString(&existingRules[k])] = &existingRules[k]
+	}
+
+	for k := range rs {
 		if _, ok := existing[k]; !ok {
-			if err := t.add(r[k]); err != nil {
-				return fmt.Errorf("failed to add route %q: %v", routeToString(r[k]), err)
+			switch r := rs[k].(type) {
+			case *netlink.Route:
+				if err := t.addRoute(r); err != nil {
+					return fmt.Errorf("failed to add route %q: %v", k, err)
+				}
+			case *netlink.Rule:
+				if err := t.addRule(r); err != nil {
+					return fmt.Errorf("failed to add rule %q: %v", k, err)
+				}
 			}
-			t.routes[k] = r[k]
+			t.rs[k] = rs[k]
 		}
 	}
 	return nil
@@ -190,3 +247,18 @@ func routeToString(route *netlink.Route) string {
 	}
 	return fmt.Sprintf("dst: %s, via: %s, src: %s, dev: %d", route.Dst.String(), gw, src, route.LinkIndex)
 }
+
+func ruleToString(rule *netlink.Rule) string {
+	if rule == nil || (rule.Src == nil && rule.Dst == nil) {
+		return ""
+	}
+	src := "-"
+	if rule.Src != nil {
+		src = rule.Src.String()
+	}
+	dst := "-"
+	if rule.Dst != nil {
+		dst = rule.Dst.String()
+	}
+	return fmt.Sprintf("src: %s, dst: %s, table: %d, input: %s", src, dst, rule.Table, rule.IifName)
+}

+ 188 - 69
pkg/route/route_test.go

@@ -31,36 +31,54 @@ func TestSet(t *testing.T) {
 	if err != nil {
 		t.Fatalf("failed to parse CIDR: %v", err)
 	}
-	add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	addRoute := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			backend[routeToString(r)] = r
 			return nil
 		}
 	}
-	del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	delRoute := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			delete(backend, routeToString(r))
 			return nil
 		}
 	}
-	adderr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	addRule := func(backend map[string]interface{}) func(*netlink.Rule) error {
+		return func(r *netlink.Rule) error {
+			backend[ruleToString(r)] = r
+			return nil
+		}
+	}
+	delRule := func(backend map[string]interface{}) func(*netlink.Rule) error {
+		return func(r *netlink.Rule) error {
+			delete(backend, ruleToString(r))
+			return nil
+		}
+	}
+	adderr := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			return errors.New(routeToString(r))
 		}
 	}
 	for _, tc := range []struct {
-		name   string
-		routes []*netlink.Route
-		err    bool
-		add    func(map[string]*netlink.Route) func(*netlink.Route) error
-		del    func(map[string]*netlink.Route) func(*netlink.Route) error
+		name     string
+		routes   []*netlink.Route
+		rules    []*netlink.Rule
+		err      bool
+		addRoute func(map[string]interface{}) func(*netlink.Route) error
+		delRoute func(map[string]interface{}) func(*netlink.Route) error
+		addRule  func(map[string]interface{}) func(*netlink.Rule) error
+		delRule  func(map[string]interface{}) func(*netlink.Rule) error
 	}{
 		{
-			name:   "empty",
-			routes: nil,
-			err:    false,
-			add:    add,
-			del:    del,
+			name:     "empty",
+			routes:   nil,
+			rules:    nil,
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "single",
@@ -70,9 +88,17 @@ func TestSet(t *testing.T) {
 					Gw:  net.ParseIP("10.1.0.1"),
 				},
 			},
-			err: false,
-			add: add,
-			del: del,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+			},
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "multiple",
@@ -86,16 +112,30 @@ func TestSet(t *testing.T) {
 					Gw:  net.ParseIP("127.0.0.1"),
 				},
 			},
-			err: false,
-			add: add,
-			del: del,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+				{
+					Src:   c2,
+					Table: 2,
+				},
+			},
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
-			name:   "err empty",
-			routes: nil,
-			err:    false,
-			add:    adderr,
-			del:    del,
+			name:     "err empty",
+			routes:   nil,
+			err:      false,
+			addRoute: adderr,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "err",
@@ -109,18 +149,30 @@ func TestSet(t *testing.T) {
 					Gw:  net.ParseIP("127.0.0.1"),
 				},
 			},
-			err: true,
-			add: adderr,
-			del: del,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+				{
+					Src:   c2,
+					Table: 2,
+				},
+			},
+			err:      true,
+			addRoute: adderr,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 	} {
-		backend := make(map[string]*netlink.Route)
-		a := tc.add(backend)
-		d := tc.del(backend)
+		backend := make(map[string]interface{})
 		table := NewTable()
-		table.add = a
-		table.del = d
-		if err := table.Set(tc.routes); (err != nil) != tc.err {
+		table.addRoute = tc.addRoute(backend)
+		table.delRoute = tc.delRoute(backend)
+		table.addRule = tc.addRule(backend)
+		table.delRule = tc.delRule(backend)
+		if err := table.Set(tc.routes, tc.rules); (err != nil) != tc.err {
 			no := "no"
 			if tc.err {
 				no = "an"
@@ -131,11 +183,18 @@ func TestSet(t *testing.T) {
 		if !tc.err {
 			for _, r := range tc.routes {
 				r1 := backend[routeToString(r)]
-				r2 := table.routes[routeToString(r)]
+				r2 := table.rs[routeToString(r)]
 				if r != r1 || r != r2 {
 					t.Errorf("test case %q: expected all routes to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
 				}
 			}
+			for _, r := range tc.rules {
+				r1 := backend[ruleToString(r)]
+				r2 := table.rs[ruleToString(r)]
+				if r != r1 || r != r2 {
+					t.Errorf("test case %q: expected all rules to be equal: expected %v, got %v and %v", tc.name, r, r1, r2)
+				}
+			}
 		}
 	}
 }
@@ -149,36 +208,53 @@ func TestCleanUp(t *testing.T) {
 	if err != nil {
 		t.Fatalf("failed to parse CIDR: %v", err)
 	}
-	add := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	addRoute := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			backend[routeToString(r)] = r
 			return nil
 		}
 	}
-	del := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	delRoute := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			delete(backend, routeToString(r))
 			return nil
 		}
 	}
-	delerr := func(backend map[string]*netlink.Route) func(*netlink.Route) error {
+	addRule := func(backend map[string]interface{}) func(*netlink.Rule) error {
+		return func(r *netlink.Rule) error {
+			backend[ruleToString(r)] = r
+			return nil
+		}
+	}
+	delRule := func(backend map[string]interface{}) func(*netlink.Rule) error {
+		return func(r *netlink.Rule) error {
+			delete(backend, ruleToString(r))
+			return nil
+		}
+	}
+	delerr := func(backend map[string]interface{}) func(*netlink.Route) error {
 		return func(r *netlink.Route) error {
 			return errors.New(routeToString(r))
 		}
 	}
 	for _, tc := range []struct {
-		name   string
-		routes []*netlink.Route
-		err    bool
-		add    func(map[string]*netlink.Route) func(*netlink.Route) error
-		del    func(map[string]*netlink.Route) func(*netlink.Route) error
+		name     string
+		routes   []*netlink.Route
+		rules    []*netlink.Rule
+		err      bool
+		addRoute func(map[string]interface{}) func(*netlink.Route) error
+		delRoute func(map[string]interface{}) func(*netlink.Route) error
+		addRule  func(map[string]interface{}) func(*netlink.Rule) error
+		delRule  func(map[string]interface{}) func(*netlink.Rule) error
 	}{
 		{
-			name:   "empty",
-			routes: nil,
-			err:    false,
-			add:    add,
-			del:    del,
+			name:     "empty",
+			routes:   nil,
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "single",
@@ -188,9 +264,17 @@ func TestCleanUp(t *testing.T) {
 					Gw:  net.ParseIP("10.1.0.1"),
 				},
 			},
-			err: false,
-			add: add,
-			del: del,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+			},
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "multiple",
@@ -204,16 +288,30 @@ func TestCleanUp(t *testing.T) {
 					Gw:  net.ParseIP("127.0.0.1"),
 				},
 			},
-			err: false,
-			add: add,
-			del: del,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+				{
+					Src:   c2,
+					Table: 2,
+				},
+			},
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
-			name:   "err empty",
-			routes: nil,
-			err:    false,
-			add:    add,
-			del:    delerr,
+			name:     "err empty",
+			routes:   nil,
+			err:      false,
+			addRoute: addRoute,
+			delRoute: delRoute,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 		{
 			name: "err",
@@ -227,18 +325,30 @@ func TestCleanUp(t *testing.T) {
 					Gw:  net.ParseIP("127.0.0.1"),
 				},
 			},
-			err: true,
-			add: add,
-			del: delerr,
+			rules: []*netlink.Rule{
+				{
+					Src:   c1,
+					Table: 1,
+				},
+				{
+					Src:   c2,
+					Table: 2,
+				},
+			},
+			err:      true,
+			addRoute: addRoute,
+			delRoute: delerr,
+			addRule:  addRule,
+			delRule:  delRule,
 		},
 	} {
-		backend := make(map[string]*netlink.Route)
-		a := tc.add(backend)
-		d := tc.del(backend)
+		backend := make(map[string]interface{})
 		table := NewTable()
-		table.add = a
-		table.del = d
-		if err := table.Set(tc.routes); err != nil {
+		table.addRoute = tc.addRoute(backend)
+		table.delRoute = tc.delRoute(backend)
+		table.addRule = tc.addRule(backend)
+		table.delRule = tc.delRule(backend)
+		if err := table.Set(tc.routes, tc.rules); err != nil {
 			t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
 		}
 		if err := table.CleanUp(); (err != nil) != tc.err {
@@ -252,9 +362,18 @@ func TestCleanUp(t *testing.T) {
 		if !tc.err {
 			for _, r := range tc.routes {
 				r1 := backend[routeToString(r)]
-				r2 := table.routes[routeToString(r)]
+				r2 := table.rs[routeToString(r)]
+				if r1 != nil || r2 != nil {
+					t.Errorf("test case %q: expected all routes to be nil: expected nil, got %v and %v", tc.name, r1, r2)
+				}
+			}
+		}
+		if !tc.err {
+			for _, r := range tc.rules {
+				r1 := backend[ruleToString(r)]
+				r2 := table.rs[ruleToString(r)]
 				if r1 != nil || r2 != nil {
-					t.Errorf("test case %q: expected all routes to be nil: expected got %v and %v", tc.name, r1, r2)
+					t.Errorf("test case %q: expected all rules to be nil: expected nil, got %v and %v", tc.name, r1, r2)
 				}
 			}
 		}