Просмотр исходного кода

pkg/mesh: avoid NAT-ing packets to service CIDRs

Currently, packets to service CIDRs may be masqueraded because their
destinations are IP addresses that Kilo does not know about, so Kilo
cannot be sure whether they know about some Kilo IPs, e.g. peer IPs.
This is not terrible, but it is annoying and can prevent some advanced
use cases; see #330. This commit adds an optional `--service-cidr` flag
to the `kg` binary that can be given multiple times to specify the
service CIDRs of the cluster so that Kilo does not masquerade packets
to them.

Signed-off-by: Lucas Servén Marín <lserven@gmail.com>
Lucas Servén Marín 3 лет назад
Родитель
Commit
12ccb65edb

+ 6 - 5
cmd/kg/handlers.go

@@ -30,10 +30,11 @@ import (
 )
 
 type graphHandler struct {
-	mesh        *mesh.Mesh
-	granularity mesh.Granularity
-	hostname    *string
-	subnet      *net.IPNet
+	mesh         *mesh.Mesh
+	granularity  mesh.Granularity
+	hostname     *string
+	subnet       *net.IPNet
+	serviceCIDRs []*net.IPNet
 }
 
 func (h *graphHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -64,7 +65,7 @@ func (h *graphHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			peers[p.Name] = p
 		}
 	}
-	topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, wgtypes.Key{}, h.subnet, nodes[*h.hostname].PersistentKeepalive, nil)
+	topo, err := mesh.NewTopology(nodes, peers, h.granularity, *h.hostname, 0, wgtypes.Key{}, h.subnet, h.serviceCIDRs, nodes[*h.hostname].PersistentKeepalive, nil)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to create topology: %v", err), http.StatusInternalServerError)
 		return

+ 14 - 2
cmd/kg/main.go

@@ -111,6 +111,7 @@ var (
 	mtu                   uint
 	topologyLabel         string
 	port                  int
+	serviceCIDRsRaw       []string
 	subnet                string
 	resyncPeriod          time.Duration
 	iptablesForwardRule   bool
@@ -141,6 +142,7 @@ func init() {
 	cmd.Flags().UintVar(&mtu, "mtu", wireguard.DefaultMTU, "The MTU of the WireGuard interface created by Kilo.")
 	cmd.Flags().StringVar(&topologyLabel, "topology-label", k8s.RegionLabelKey, "Kubernetes node label used to group nodes into logical locations.")
 	cmd.Flags().IntVar(&port, "port", mesh.DefaultKiloPort, "The port over which WireGuard peers should communicate.")
+	cmd.Flags().StringSliceVar(&serviceCIDRsRaw, "service-cidr", nil, "The service CIDR for the Kubernetes cluster. Can be provided optionally to avoid masquerading packets sent to service IPs. Can be specified multiple times.")
 	cmd.Flags().StringVar(&subnet, "subnet", mesh.DefaultKiloSubnet.String(), "CIDR from which to allocate addresses for WireGuard interfaces.")
 	cmd.Flags().DurationVar(&resyncPeriod, "resync-period", 30*time.Second, "How often should the Kilo controllers reconcile?")
 	cmd.Flags().BoolVar(&iptablesForwardRule, "iptables-forward-rules", false, "Add default accept rules to the FORWARD chain in iptables. Warning: this may break firewalls with a deny all policy and is potentially insecure!")
@@ -245,7 +247,17 @@ func runRoot(_ *cobra.Command, _ []string) error {
 	if port < 1 || port > 1<<16-1 {
 		return fmt.Errorf("invalid port: port mus be in range [%d:%d], but got %d", 1, 1<<16-1, port)
 	}
-	m, err := mesh.New(b, enc, gr, hostname, port, s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, log.With(logger, "component", "kilo"), registry)
+
+	var serviceCIDRs []*net.IPNet
+	for _, serviceCIDR := range serviceCIDRsRaw {
+		_, s, err := net.ParseCIDR(serviceCIDR)
+		if err != nil {
+			return fmt.Errorf("failed to parse %q as CIDR: %v", serviceCIDR, err)
+		}
+		serviceCIDRs = append(serviceCIDRs, s)
+	}
+
+	m, err := mesh.New(b, enc, gr, hostname, port, s, local, cni, cniPath, iface, cleanUpIface, createIface, mtu, resyncPeriod, prioritisePrivateAddr, iptablesForwardRule, serviceCIDRs, log.With(logger, "component", "kilo"), registry)
 	if err != nil {
 		return fmt.Errorf("failed to create Kilo mesh: %v", err)
 	}
@@ -258,7 +270,7 @@ func runRoot(_ *cobra.Command, _ []string) error {
 			internalserver.WithPProf(),
 		)
 		h.AddEndpoint("/health", "Exposes health checks", healthHandler)
-		h.AddEndpoint("/graph", "Exposes Kilo mesh topology graph", (&graphHandler{m, gr, &hostname, s}).ServeHTTP)
+		h.AddEndpoint("/graph", "Exposes Kilo mesh topology graph", (&graphHandler{m, gr, &hostname, s, serviceCIDRs}).ServeHTTP)
 		// Run the HTTP server.
 		l, err := net.Listen("tcp", listen)
 		if err != nil {

+ 1 - 1
cmd/kgctl/connect_linux.go

@@ -330,7 +330,7 @@ func sync(table *route.Table, peerName string, privateKey wgtypes.Key, iface int
 		return fmt.Errorf("did not find any peer named %q in the cluster", peerName)
 	}
 
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, wgtypes.Key{}, subnet, *peers[peerName].PersistentKeepaliveInterval, logger)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, opts.port, wgtypes.Key{}, subnet, nil, *peers[peerName].PersistentKeepaliveInterval, logger)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %w", err)
 	}

+ 1 - 1
cmd/kgctl/graph.go

@@ -67,7 +67,7 @@ func runGraph(_ *cobra.Command, _ []string) error {
 			peers[p.Name] = p
 		}
 	}
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, 0, wgtypes.Key{}, subnet, nil, nodes[hostname].PersistentKeepalive, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %w", err)
 	}

+ 2 - 2
cmd/kgctl/showconf.go

@@ -152,7 +152,7 @@ func runShowConfNode(_ *cobra.Command, args []string) error {
 		}
 	}
 
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, int(opts.port), wgtypes.Key{}, subnet, nodes[hostname].PersistentKeepalive, nil)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, int(opts.port), wgtypes.Key{}, subnet, nil, nodes[hostname].PersistentKeepalive, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %w", err)
 	}
@@ -255,7 +255,7 @@ func runShowConfPeer(_ *cobra.Command, args []string) error {
 	if p := peers[peer].PersistentKeepaliveInterval; p != nil {
 		pka = *p
 	}
-	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, wgtypes.Key{}, subnet, pka, nil)
+	t, err := mesh.NewTopology(nodes, peers, opts.granularity, hostname, mesh.DefaultKiloPort, wgtypes.Key{}, subnet, nil, pka, nil)
 	if err != nil {
 		return fmt.Errorf("failed to create topology: %w", err)
 	}

+ 4 - 2
pkg/mesh/mesh.go

@@ -68,6 +68,7 @@ type Mesh struct {
 	pub                 wgtypes.Key
 	resyncPeriod        time.Duration
 	iptablesForwardRule bool
+	serviceCIDRs        []*net.IPNet
 	subnet              *net.IPNet
 	table               *route.Table
 	wireGuardIP         *net.IPNet
@@ -87,7 +88,7 @@ type Mesh struct {
 }
 
 // New returns a new Mesh instance.
-func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, logger log.Logger, registerer prometheus.Registerer) (*Mesh, error) {
+func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularity, hostname string, port int, subnet *net.IPNet, local, cni bool, cniPath, iface string, cleanUpIface bool, createIface bool, mtu uint, resyncPeriod time.Duration, prioritisePrivateAddr, iptablesForwardRule bool, serviceCIDRs []*net.IPNet, logger log.Logger, registerer prometheus.Registerer) (*Mesh, error) {
 	if err := os.MkdirAll(kiloPath, 0700); err != nil {
 		return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
 	}
@@ -181,6 +182,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 		resyncPeriod:        resyncPeriod,
 		iptablesForwardRule: iptablesForwardRule,
 		local:               local,
+		serviceCIDRs:        serviceCIDRs,
 		subnet:              subnet,
 		table:               route.NewTable(),
 		errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
@@ -494,7 +496,7 @@ func (m *Mesh) applyTopology() {
 
 	natEndpoints := discoverNATEndpoints(nodes, peers, wgDevice, m.logger)
 	nodes[m.hostname].DiscoveredEndpoints = natEndpoints
-	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
+	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port(), m.priv, m.subnet, m.serviceCIDRs, nodes[m.hostname].PersistentKeepalive, m.logger)
 	if err != nil {
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()

+ 3 - 0
pkg/mesh/routes.go

@@ -370,6 +370,9 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
 			)
 		}
 	}
+	for _, s := range t.serviceCIDRs {
+		rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.IP), "nat", "KILO-NAT", "-d", s.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for service CIDRs", "-j", "RETURN"))
+	}
 	rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
 	rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
 	return rules

+ 6 - 1
pkg/mesh/topology.go

@@ -60,6 +60,10 @@ type Topology struct {
 	// the IP is the 0th address in the subnet, i.e. the CIDR
 	// is equal to the Kilo subnet.
 	wireGuardCIDR *net.IPNet
+	// serviceCIDRs are the known service CIDRs of the Kubernetes cluster.
+	// They are not strictly needed, however if they are known,
+	// then the topology can avoid masquerading packets destined to service IPs.
+	serviceCIDRs []*net.IPNet
 	// discoveredEndpoints is the updated map of valid discovered Endpoints
 	discoveredEndpoints map[string]*net.UDPAddr
 	logger              log.Logger
@@ -92,7 +96,7 @@ type segment struct {
 }
 
 // NewTopology creates a new Topology struct from a given set of nodes and peers.
-func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
+func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, serviceCIDRs []*net.IPNet, persistentKeepalive time.Duration, logger log.Logger) (*Topology, error) {
 	if logger == nil {
 		logger = log.NewNopLogger()
 	}
@@ -132,6 +136,7 @@ func NewTopology(nodes map[string]*Node, peers map[string]*Peer, granularity Gra
 		privateIP:           nodes[hostname].InternalIP,
 		subnet:              nodes[hostname].Subnet,
 		wireGuardCIDR:       subnet,
+		serviceCIDRs:        serviceCIDRs,
 		discoveredEndpoints: make(map[string]*net.UDPAddr),
 		logger:              logger,
 	}

+ 2 - 2
pkg/mesh/topology_test.go

@@ -535,7 +535,7 @@ func TestNewTopology(t *testing.T) {
 	} {
 		tc.result.key = key
 		tc.result.port = port
-		topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet, 0, nil)
+		topo, err := NewTopology(nodes, peers, tc.granularity, tc.hostname, port, key, DefaultKiloSubnet, nil, 0, nil)
 		if err != nil {
 			t.Errorf("test case %q: failed to generate Topology: %v", tc.name, err)
 		}
@@ -546,7 +546,7 @@ func TestNewTopology(t *testing.T) {
 }
 
 func mustTopo(t *testing.T, nodes map[string]*Node, peers map[string]*Peer, granularity Granularity, hostname string, port int, key wgtypes.Key, subnet *net.IPNet, persistentKeepalive time.Duration) *Topology {
-	topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, persistentKeepalive, nil)
+	topo, err := NewTopology(nodes, peers, granularity, hostname, port, key, subnet, nil, persistentKeepalive, nil)
 	if err != nil {
 		t.Errorf("failed to generate Topology: %v", err)
 	}