ソースを参照

pkg/{encapsulation,mesh}: abstract encapsulation

This commit abstracts away encapsulation to more easily allow for
different types of encapsulation or compatibility with other networking
solutions.
Lucas Servén Marín 7 年 前
コミット
cd6eeeb1e7

+ 12 - 9
cmd/kg/main.go

@@ -34,6 +34,7 @@ import (
 	"k8s.io/client-go/kubernetes"
 	"k8s.io/client-go/tools/clientcmd"
 
+	"github.com/squat/kilo/pkg/encapsulation"
 	"github.com/squat/kilo/pkg/k8s"
 	kiloclient "github.com/squat/kilo/pkg/k8s/clientset/versioned"
 	"github.com/squat/kilo/pkg/mesh"
@@ -54,9 +55,9 @@ var (
 		k8s.Backend,
 	}, ", ")
 	availableEncapsulations = strings.Join([]string{
-		string(mesh.NeverEncapsulate),
-		string(mesh.CrossSubnetEncapsulate),
-		string(mesh.AlwaysEncapsulate),
+		string(encapsulation.Never),
+		string(encapsulation.CrossSubnet),
+		string(encapsulation.Always),
 	}, ", ")
 	availableGranularities = strings.Join([]string{
 		string(mesh.LogicalGranularity),
@@ -77,7 +78,7 @@ func Main() error {
 	backend := flag.String("backend", k8s.Backend, fmt.Sprintf("The backend for the mesh. Possible values: %s", availableBackends))
 	cni := flag.Bool("cni", true, "Should Kilo manage the node's CNI configuration.")
 	cniPath := flag.String("cni-path", mesh.DefaultCNIPath, "Path to CNI config.")
-	encapsulate := flag.String("encapsulate", string(mesh.AlwaysEncapsulate), fmt.Sprintf("When should Kilo encapsulate packets within a location. Possible values: %s", availableEncapsulations))
+	encapsulate := flag.String("encapsulate", string(encapsulation.Always), fmt.Sprintf("When should Kilo encapsulate packets within a location. Possible values: %s", availableEncapsulations))
 	granularity := flag.String("mesh-granularity", string(mesh.LogicalGranularity), fmt.Sprintf("The granularity of the network mesh to create. Possible values: %s", availableGranularities))
 	kubeconfig := flag.String("kubeconfig", "", "Path to kubeconfig.")
 	hostname := flag.String("hostname", "", "Hostname of the node on which this process is running.")
@@ -129,14 +130,16 @@ func Main() error {
 	logger = log.With(logger, "ts", log.DefaultTimestampUTC)
 	logger = log.With(logger, "caller", log.DefaultCaller)
 
-	e := mesh.Encapsulate(*encapsulate)
+	var enc encapsulation.Interface
+	e := encapsulation.Strategy(*encapsulate)
 	switch e {
-	case mesh.NeverEncapsulate:
-	case mesh.CrossSubnetEncapsulate:
-	case mesh.AlwaysEncapsulate:
+	case encapsulation.Never:
+	case encapsulation.CrossSubnet:
+	case encapsulation.Always:
 	default:
 		return fmt.Errorf("encapsulation %v unknown; possible values are: %s", *encapsulate, availableEncapsulations)
 	}
+	enc = encapsulation.NewIPIP(e)
 
 	gr := mesh.Granularity(*granularity)
 	switch gr {
@@ -161,7 +164,7 @@ func Main() error {
 		return fmt.Errorf("backend %v unknown; possible values are: %s", *backend, availableBackends)
 	}
 
-	m, err := mesh.New(b, e, gr, *hostname, uint32(port), s, *local, *cni, *cniPath, log.With(logger, "component", "kilo"))
+	m, err := mesh.New(b, enc, gr, *hostname, uint32(port), s, *local, *cni, *cniPath, log.With(logger, "component", "kilo"))
 	if err != nil {
 		return fmt.Errorf("failed to create Kilo mesh: %v", err)
 	}

+ 107 - 0
pkg/encapsulation/ipip.go

@@ -0,0 +1,107 @@
+// Copyright 2019 the Kilo authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package encapsulation
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/squat/kilo/pkg/iproute"
+	"github.com/squat/kilo/pkg/iptables"
+)
+
+// Strategy identifies which packets within a location should
+// be encapsulated.
+type Strategy string
+
+const (
+	// Never indicates that no packets within a location
+	// should be encapsulated.
+	Never Strategy = "never"
+	// CrossSubnet indicates that only packets that
+	// traverse subnets within a location should be encapsulated.
+	CrossSubnet Strategy = "crosssubnet"
+	// Always indicates that all packets within a location
+	// should be encapsulated.
+	Always Strategy = "always"
+)
+
+// Interface can configure
+// the encapsulation interface, init itself,
+// get the encapsulation interface index,
+// set the interface IP address,
+// return the required IPTables rules,
+// return the encapsulation strategy,
+// and clean up any changes applied to the backend.
+type Interface interface {
+	CleanUp() error
+	Index() int
+	Init(int) error
+	Rules([]*net.IPNet) []iptables.Rule
+	Set(*net.IPNet) error
+	Strategy() Strategy
+}
+
+type ipip struct {
+	iface    int
+	strategy Strategy
+}
+
+// NewIPIP returns an encapsulation that uses IPIP.
+func NewIPIP(strategy Strategy) Interface {
+	return &ipip{strategy: strategy}
+}
+
+// CleanUp will remove any created IPIP devices.
+func (i *ipip) CleanUp() error {
+	if err := iproute.DeleteAddresses(i.iface); err != nil {
+		return nil
+	}
+	return iproute.RemoveInterface(i.iface)
+}
+
+// Index returns the index of the IPIP interface.
+func (i *ipip) Index() int {
+	return i.iface
+}
+
+// Init initializes the IPIP interface.
+func (i *ipip) Init(base int) error {
+	iface, err := iproute.NewIPIP(base)
+	if err != nil {
+		return fmt.Errorf("failed to create tunnel interface: %v", err)
+	}
+	if err := iproute.Set(iface, true); err != nil {
+		return fmt.Errorf("failed to set tunnel interface up: %v", err)
+	}
+	i.iface = iface
+	return nil
+}
+
+// Rules returns a set of iptables rules that are necessary
+// when traffic between nodes must be encapsulated.
+func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
+	return iptables.IPIPRules(nodes)
+}
+
+// Set sets the IP address of the IPIP interface.
+func (i *ipip) Set(cidr *net.IPNet) error {
+	return iproute.SetAddress(i.iface, cidr)
+}
+
+// Strategy returns the configured strategy for encapsulation.
+func (i *ipip) Strategy() Strategy {
+	return i.strategy
+}

+ 18 - 0
pkg/iproute/iproute.go

@@ -68,3 +68,21 @@ func SetAddress(index int, cidr *net.IPNet) error {
 	}
 	return netlink.AddrReplace(link, &netlink.Addr{IPNet: cidr})
 }
+
+// DeleteAddresses removes all IP addresses of an interface.
+func DeleteAddresses(index int) error {
+	link, err := netlink.LinkByIndex(index)
+	if err != nil {
+		return fmt.Errorf("failed to get link: %s", err)
+	}
+	addrs, err := netlink.AddrList(link, netlink.FAMILY_ALL)
+	if err != nil {
+		return err
+	}
+	for _, addr := range addrs {
+		if err := netlink.AddrDel(link, &addr); err != nil {
+			return fmt.Errorf("failed to delete address: %s", err)
+		}
+	}
+	return nil
+}

+ 3 - 3
pkg/iptables/iptables.go

@@ -241,9 +241,9 @@ func (c *Controller) CleanUp() error {
 	return nil
 }
 
-// EncapsulateRules returns a set of iptables rules that are necessary
-// when traffic between nodes must be encapsulated.
-func EncapsulateRules(nodes []*net.IPNet) []Rule {
+// IPIPRules returns a set of iptables rules that are necessary
+// when traffic between nodes must be encapsulated with IPIP.
+func IPIPRules(nodes []*net.IPNet) []Rule {
 	var rules []Rule
 	rules = append(rules, &chain{"filter", "KILO-IPIP", nil})
 	rules = append(rules, &rule{"filter", "INPUT", []string{"-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-p", "4", "-j", "KILO-IPIP"}, nil})

+ 23 - 30
pkg/mesh/mesh.go

@@ -28,6 +28,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/vishvananda/netlink"
 
+	"github.com/squat/kilo/pkg/encapsulation"
 	"github.com/squat/kilo/pkg/iproute"
 	"github.com/squat/kilo/pkg/iptables"
 	"github.com/squat/kilo/pkg/route"
@@ -56,10 +57,6 @@ var DefaultKiloSubnet = &net.IPNet{IP: []byte{10, 4, 0, 0}, Mask: []byte{255, 25
 // should be meshed.
 type Granularity string
 
-// Encapsulate identifies what packets within a location should
-// be encapsulated.
-type Encapsulate string
-
 const (
 	// LogicalGranularity indicates that the network should create
 	// a mesh between logical locations, e.g. data-centers, but not between
@@ -68,15 +65,6 @@ const (
 	// FullGranularity indicates that the network should create
 	// a mesh between every node.
 	FullGranularity Granularity = "full"
-	// NeverEncapsulate indicates that no packets within a location
-	// should be encapsulated.
-	NeverEncapsulate Encapsulate = "never"
-	// CrossSubnetEncapsulate indicates that only packets that
-	// traverse subnets within a location should be encapsulated.
-	CrossSubnetEncapsulate Encapsulate = "crosssubnet"
-	// AlwaysEncapsulate indicates that all packets within a location
-	// should be encapsulated.
-	AlwaysEncapsulate Encapsulate = "always"
 )
 
 // Node represents a node in the network.
@@ -181,7 +169,7 @@ type Mesh struct {
 	Backend
 	cni         bool
 	cniPath     string
-	encapsulate Encapsulate
+	enc         encapsulation.Interface
 	externalIP  *net.IPNet
 	granularity Granularity
 	hostname    string
@@ -198,7 +186,6 @@ type Mesh struct {
 	stop        chan struct{}
 	subnet      *net.IPNet
 	table       *route.Table
-	tunlIface   int
 	wireGuardIP *net.IPNet
 
 	// nodes and peers are mutable fields in the struct
@@ -215,7 +202,7 @@ type Mesh struct {
 }
 
 // New returns a new Mesh instance.
-func New(backend Backend, encapsulate Encapsulate, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath string, logger log.Logger) (*Mesh, error) {
+func New(backend Backend, enc encapsulation.Interface, granularity Granularity, hostname string, port uint32, subnet *net.IPNet, local, cni bool, cniPath string, logger log.Logger) (*Mesh, error) {
 	if err := os.MkdirAll(KiloPath, 0700); err != nil {
 		return nil, fmt.Errorf("failed to create directory to store configuration: %v", err)
 	}
@@ -238,7 +225,7 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
 	if err != nil {
 		return nil, fmt.Errorf("failed to query netlink for CNI device: %v", err)
 	}
-	privateIP, publicIP, err := getIP(hostname, cniIndex)
+	privateIP, publicIP, err := getIP(hostname, enc.Index(), cniIndex)
 	if err != nil {
 		return nil, fmt.Errorf("failed to find public IP: %v", err)
 	}
@@ -256,13 +243,9 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
 	if err != nil {
 		return nil, fmt.Errorf("failed to create WireGuard interface: %v", err)
 	}
-	var tunlIface int
-	if encapsulate != NeverEncapsulate {
-		if tunlIface, err = iproute.NewIPIP(privIface); err != nil {
-			return nil, fmt.Errorf("failed to create tunnel interface: %v", err)
-		}
-		if err := iproute.Set(tunlIface, true); err != nil {
-			return nil, fmt.Errorf("failed to set tunnel interface up: %v", err)
+	if enc.Strategy() != encapsulation.Never {
+		if err := enc.Init(privIface); err != nil {
+			return nil, fmt.Errorf("failed to initialize encapsulation: %v", err)
 		}
 	}
 	level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the private IP address", privateIP.String()))
@@ -275,7 +258,7 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
 		Backend:     backend,
 		cni:         cni,
 		cniPath:     cniPath,
-		encapsulate: encapsulate,
+		enc:         enc,
 		externalIP:  publicIP,
 		granularity: granularity,
 		hostname:    hostname,
@@ -293,7 +276,6 @@ func New(backend Backend, encapsulate Encapsulate, granularity Granularity, host
 		stop:        make(chan struct{}),
 		subnet:      subnet,
 		table:       route.NewTable(),
-		tunlIface:   tunlIface,
 		errorCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
 			Name: "kilo_errors_total",
 			Help: "Number of errors that occurred while administering the mesh.",
@@ -319,6 +301,13 @@ func (m *Mesh) Run() error {
 	if err := m.Nodes().Init(m.stop); err != nil {
 		return fmt.Errorf("failed to initialize node backend: %v", err)
 	}
+	// Try to set the CNI config quickly.
+	if n, err := m.Nodes().Get(m.hostname); err == nil {
+		if n != nil && n.Subnet != nil {
+			m.nodes[m.hostname] = n
+			m.updateCNIConfig()
+		}
+	}
 	if err := m.Peers().Init(m.stop); err != nil {
 		return fmt.Errorf("failed to initialize peer backend: %v", err)
 	}
@@ -616,7 +605,7 @@ func (m *Mesh) applyTopology() {
 	rules = append(rules, iptables.MasqueradeRules(m.subnet, oneAddressCIDR(t.privateIP.IP), nodes[m.hostname].Subnet, t.RemoteSubnets(), peerCIDRs)...)
 	// If we are handling local routes, ensure the local
 	// tunnel has an IP address and IPIP traffic is allowed.
-	if m.encapsulate != NeverEncapsulate && m.local {
+	if m.enc.Strategy() != encapsulation.Never && m.local {
 		var cidrs []*net.IPNet
 		for _, s := range t.segments {
 			if s.location == nodes[m.hostname].Location {
@@ -626,11 +615,11 @@ func (m *Mesh) applyTopology() {
 				break
 			}
 		}
-		rules = append(rules, iptables.EncapsulateRules(cidrs)...)
+		rules = append(rules, m.enc.Rules(cidrs)...)
 
 		// If we are handling local routes, ensure the local
 		// tunnel has an IP address.
-		if err := iproute.SetAddress(m.tunlIface, oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
+		if err := m.enc.Set(oneAddressCIDR(newAllocator(*nodes[m.hostname].Subnet).next().IP)); err != nil {
 			level.Error(m.logger).Log("error", err)
 			m.errorCounter.WithLabelValues("apply").Inc()
 			return
@@ -685,7 +674,7 @@ func (m *Mesh) applyTopology() {
 	}
 	// We need to add routes last since they may depend
 	// on the WireGuard interface.
-	routes := t.Routes(m.kiloIface, m.privIface, m.tunlIface, m.local, m.encapsulate)
+	routes := t.Routes(m.kiloIface, m.privIface, m.enc.Index(), m.local, m.enc.Strategy())
 	if err := m.table.Set(routes); err != nil {
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()
@@ -733,6 +722,10 @@ func (m *Mesh) cleanUp() {
 		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up peer backend: %v", err))
 		m.errorCounter.WithLabelValues("cleanUp").Inc()
 	}
+	if err := m.enc.CleanUp(); err != nil {
+		level.Error(m.logger).Log("error", fmt.Sprintf("failed to clean up encapsulation: %v", err))
+		m.errorCounter.WithLabelValues("cleanUp").Inc()
+	}
 }
 
 func isSelf(hostname string, node *Node) bool {

+ 4 - 3
pkg/mesh/topology.go

@@ -19,6 +19,7 @@ import (
 	"net"
 	"sort"
 
+	"github.com/squat/kilo/pkg/encapsulation"
 	"github.com/squat/kilo/pkg/wireguard"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
@@ -171,7 +172,7 @@ func (t *Topology) RemoteSubnets() []*net.IPNet {
 }
 
 // Routes generates a slice of routes for a given Topology.
-func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encapsulate Encapsulate) []*netlink.Route {
+func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encapsulate encapsulation.Strategy) []*netlink.Route {
 	var routes []*netlink.Route
 	if !t.leader {
 		// Find the leader for this segment.
@@ -306,8 +307,8 @@ func (t *Topology) Routes(kiloIface, privIface, tunlIface int, local bool, encap
 	return routes
 }
 
-func encapsulateRoute(route *netlink.Route, encapsulate Encapsulate, subnet *net.IPNet, tunlIface int) *netlink.Route {
-	if encapsulate == AlwaysEncapsulate || (encapsulate == CrossSubnetEncapsulate && !subnet.Contains(route.Gw)) {
+func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, subnet *net.IPNet, tunlIface int) *netlink.Route {
+	if encapsulate == encapsulation.Always || (encapsulate == encapsulation.CrossSubnet && !subnet.Contains(route.Gw)) {
 		route.LinkIndex = tunlIface
 	}
 	return route

+ 2 - 1
pkg/mesh/topology_test.go

@@ -20,6 +20,7 @@ import (
 	"testing"
 
 	"github.com/kylelemons/godebug/pretty"
+	"github.com/squat/kilo/pkg/encapsulation"
 	"github.com/squat/kilo/pkg/wireguard"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
@@ -978,7 +979,7 @@ func TestRoutes(t *testing.T) {
 			},
 		},
 	} {
-		routes := tc.topology.Routes(kiloIface, privIface, pubIface, tc.local, NeverEncapsulate)
+		routes := tc.topology.Routes(kiloIface, privIface, pubIface, tc.local, encapsulation.Never)
 		if diff := pretty.Compare(routes, tc.result); diff != "" {
 			t.Errorf("test case %q: got diff: %v", tc.name, diff)
 		}