Sfoglia il codice sorgente

Use LatestHandshake to validate endpoint (#149)

* wireguard: `wg show iface dump` reader and parser

* mesh: use LatestHandshake to validate NAT Endpoints

* add skip on error

* switch to loop parsing

So the stop on error pattern can be used

* Add error handling to ParseDump
Julien Viard de Galbert 4 anni fa
parent
commit
e12b5029d7
4 ha cambiato i file con 277 aggiunte e 47 eliminazioni
  1. 10 7
      pkg/mesh/mesh.go
  2. 209 40
      pkg/wireguard/conf.go
  3. 46 0
      pkg/wireguard/conf_test.go
  4. 12 0
      pkg/wireguard/wireguard.go

+ 10 - 7
pkg/mesh/mesh.go

@@ -454,13 +454,18 @@ func (m *Mesh) applyTopology() {
 		return
 		return
 	}
 	}
 	// Find the old configuration.
 	// Find the old configuration.
-	oldConfRaw, err := wireguard.ShowConf(link.Attrs().Name)
+	oldConfDump, err := wireguard.ShowDump(link.Attrs().Name)
+	if err != nil {
+		level.Error(m.logger).Log("error", err)
+		m.errorCounter.WithLabelValues("apply").Inc()
+		return
+	}
+	oldConf, err := wireguard.ParseDump(oldConfDump)
 	if err != nil {
 	if err != nil {
 		level.Error(m.logger).Log("error", err)
 		level.Error(m.logger).Log("error", err)
 		m.errorCounter.WithLabelValues("apply").Inc()
 		m.errorCounter.WithLabelValues("apply").Inc()
 		return
 		return
 	}
 	}
-	oldConf := wireguard.Parse(oldConfRaw)
 	natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
 	natEndpoints := discoverNATEndpoints(nodes, peers, oldConf, m.logger)
 	nodes[m.hostname].DiscoveredEndpoints = natEndpoints
 	nodes[m.hostname].DiscoveredEndpoints = natEndpoints
 	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
 	t, err := NewTopology(nodes, peers, m.granularity, m.hostname, nodes[m.hostname].Endpoint.Port, m.priv, m.subnet, nodes[m.hostname].PersistentKeepalive, m.logger)
@@ -782,17 +787,15 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
 	}
 	}
 	for _, n := range nodes {
 	for _, n := range nodes {
 		if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
 		if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
-			level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
-			// Should check location leader but only available in topology ... or have topology handle that list
-			// Better check wg latest-handshake
-			if !n.Endpoint.Equal(peer.Endpoint, false) {
+			level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false), "latest-handshake", peer.LatestHandshake)
+			if (peer.LatestHandshake != time.Time{}) {
 				natEndpoints[string(n.Key)] = peer.Endpoint
 				natEndpoints[string(n.Key)] = peer.Endpoint
 			}
 			}
 		}
 		}
 	}
 	}
 	for _, p := range peers {
 	for _, p := range peers {
 		if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
 		if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
-			if !p.Endpoint.Equal(peer.Endpoint, false) {
+			if (peer.LatestHandshake != time.Time{}) {
 				natEndpoints[string(p.PublicKey)] = peer.Endpoint
 				natEndpoints[string(p.PublicKey)] = peer.Endpoint
 			}
 			}
 		}
 		}

+ 209 - 40
pkg/wireguard/conf.go

@@ -17,11 +17,13 @@ package wireguard
 import (
 import (
 	"bufio"
 	"bufio"
 	"bytes"
 	"bytes"
+	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"sort"
 	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+	"time"
 
 
 	"k8s.io/apimachinery/pkg/util/validation"
 	"k8s.io/apimachinery/pkg/util/validation"
 )
 )
@@ -31,6 +33,9 @@ type key string
 
 
 const (
 const (
 	separator                      = "="
 	separator                      = "="
+	dumpSeparator                  = "\t"
+	dumpNone                       = "(none)"
+	dumpOff                        = "off"
 	interfaceSection       section = "Interface"
 	interfaceSection       section = "Interface"
 	peerSection            section = "Peer"
 	peerSection            section = "Peer"
 	listenPortKey          key     = "ListenPort"
 	listenPortKey          key     = "ListenPort"
@@ -42,6 +47,30 @@ const (
 	publicKeyKey           key     = "PublicKey"
 	publicKeyKey           key     = "PublicKey"
 )
 )
 
 
+type dumpInterfaceIndex int
+
+const (
+	dumpInterfacePrivateKeyIndex = iota
+	dumpInterfacePublicKeyIndex
+	dumpInterfaceListenPortIndex
+	dumpInterfaceFWMarkIndex
+	dumpInterfaceLen
+)
+
+type dumpPeerIndex int
+
+const (
+	dumpPeerPublicKeyIndex = iota
+	dumpPeerPresharedKeyIndex
+	dumpPeerEndpointIndex
+	dumpPeerAllowedIPsIndex
+	dumpPeerLatestHandshakeIndex
+	dumpPeerTransferRXIndex
+	dumpPeerTransferTXIndex
+	dumpPeerPersistentKeepaliveIndex
+	dumpPeerLen
+)
+
 // Conf represents a WireGuard configuration file.
 // Conf represents a WireGuard configuration file.
 type Conf struct {
 type Conf struct {
 	Interface *Interface
 	Interface *Interface
@@ -61,6 +90,8 @@ type Peer struct {
 	PersistentKeepalive int
 	PersistentKeepalive int
 	PresharedKey        []byte
 	PresharedKey        []byte
 	PublicKey           []byte
 	PublicKey           []byte
+	// The following fields are part of the runtime information, not the configuration.
+	LatestHandshake time.Time
 }
 }
 
 
 // DeduplicateIPs eliminates duplicate allowed IPs.
 // DeduplicateIPs eliminates duplicate allowed IPs.
@@ -146,13 +177,11 @@ func (d DNSOrIP) String() string {
 func Parse(buf []byte) *Conf {
 func Parse(buf []byte) *Conf {
 	var (
 	var (
 		active  section
 		active  section
-		ai      *net.IPNet
 		kv      []string
 		kv      []string
 		c       Conf
 		c       Conf
 		err     error
 		err     error
 		iface   *Interface
 		iface   *Interface
 		i       int
 		i       int
-		ip, ip4 net.IP
 		k       key
 		k       key
 		line, v string
 		line, v string
 		peer    *Peer
 		peer    *Peer
@@ -205,49 +234,15 @@ func Parse(buf []byte) *Conf {
 		case peerSection:
 		case peerSection:
 			switch k {
 			switch k {
 			case allowedIPsKey:
 			case allowedIPsKey:
-				// Reuse string slice.
-				kv = strings.Split(v, ",")
-				for i = range kv {
-					ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
-					if err != nil {
-						continue
-					}
-					if ip4 = ip.To4(); ip4 != nil {
-						ip = ip4
-					} else {
-						ip = ip.To16()
-					}
-					ai.IP = ip
-					peer.AllowedIPs = append(peer.AllowedIPs, ai)
-				}
-			case endpointKey:
-				// Reuse string slice.
-				kv = strings.Split(v, ":")
-				if len(kv) < 2 {
+				err = peer.parseAllowedIPs(v)
+				if err != nil {
 					continue
 					continue
 				}
 				}
-				port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
+			case endpointKey:
+				err = peer.parseEndpoint(v)
 				if err != nil {
 				if err != nil {
 					continue
 					continue
 				}
 				}
-				d := DNSOrIP{}
-				ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
-				if ip == nil {
-					if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
-						continue
-					}
-					d.DNS = kv[0]
-				} else {
-					if ip4 = ip.To4(); ip4 != nil {
-						d.IP = ip4
-					} else {
-						d.IP = ip.To16()
-					}
-				}
-				peer.Endpoint = &Endpoint{
-					DNSOrIP: d,
-					Port:    uint32(port),
-				}
 			case persistentKeepaliveKey:
 			case persistentKeepaliveKey:
 				i, err = strconv.Atoi(v)
 				i, err = strconv.Atoi(v)
 				if err != nil {
 				if err != nil {
@@ -448,3 +443,177 @@ func writeKey(buf *bytes.Buffer, k key) error {
 	_, err = buf.WriteString(" = ")
 	_, err = buf.WriteString(" = ")
 	return err
 	return err
 }
 }
+
+var (
+	errParseEndpoint = errors.New("could not parse Endpoint")
+)
+
+func (p *Peer) parseEndpoint(v string) error {
+	var (
+		kv      []string
+		err     error
+		ip, ip4 net.IP
+		port    uint64
+	)
+	kv = strings.Split(v, ":")
+	if len(kv) < 2 {
+		return errParseEndpoint
+	}
+	port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
+	if err != nil {
+		return err
+	}
+	d := DNSOrIP{}
+	ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
+	if ip == nil {
+		if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
+			return errParseEndpoint
+		}
+		d.DNS = kv[0]
+	} else {
+		if ip4 = ip.To4(); ip4 != nil {
+			d.IP = ip4
+		} else {
+			d.IP = ip.To16()
+		}
+	}
+
+	p.Endpoint = &Endpoint{
+		DNSOrIP: d,
+		Port:    uint32(port),
+	}
+	return nil
+}
+
+func (p *Peer) parseAllowedIPs(v string) error {
+	var (
+		ai      *net.IPNet
+		kv      []string
+		err     error
+		i       int
+		ip, ip4 net.IP
+	)
+
+	kv = strings.Split(v, ",")
+	for i = range kv {
+		ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
+		if err != nil {
+			return err
+		}
+		if ip4 = ip.To4(); ip4 != nil {
+			ip = ip4
+		} else {
+			ip = ip.To16()
+		}
+		ai.IP = ip
+		p.AllowedIPs = append(p.AllowedIPs, ai)
+	}
+	return nil
+}
+
+// ParseDump parses a given WireGuard dump and produces a Conf struct.
+func ParseDump(buf []byte) (*Conf, error) {
+	// from man wg, show section:
+	// If dump is specified, then several lines are printed;
+	// the first contains in order separated by tab: private-key, public-key, listen-port, fw‐mark.
+	// Subsequent lines are printed for each peer and contain in order separated by tab:
+	// public-key, preshared-key, endpoint, allowed-ips, latest-handshake, transfer-rx, transfer-tx, persistent-keepalive.
+	var (
+		active section
+		values []string
+		c      Conf
+		err    error
+		iface  *Interface
+		peer   *Peer
+		port   uint64
+		sec    int64
+		pka    int
+		line   int
+	)
+	// First line is Interface
+	active = interfaceSection
+	s := bufio.NewScanner(bytes.NewBuffer(buf))
+	for s.Scan() {
+		values = strings.Split(s.Text(), dumpSeparator)
+
+		switch active {
+		case interfaceSection:
+			if len(values) < dumpInterfaceLen {
+				return nil, fmt.Errorf("invalid interface line: missing fields (%d < %d)", len(values), dumpInterfaceLen)
+			}
+			iface = new(Interface)
+			for i := range values {
+				switch i {
+				case dumpInterfacePrivateKeyIndex:
+					iface.PrivateKey = []byte(values[i])
+				case dumpInterfaceListenPortIndex:
+					port, err = strconv.ParseUint(values[i], 10, 32)
+					if err != nil {
+						return nil, fmt.Errorf("invalid interface line: error parsing listen-port: %w", err)
+					}
+					iface.ListenPort = uint32(port)
+				}
+			}
+			c.Interface = iface
+			// Next lines are Peers
+			active = peerSection
+		case peerSection:
+			if len(values) < dumpPeerLen {
+				return nil, fmt.Errorf("invalid peer line %d: missing fields (%d < %d)", line, len(values), dumpPeerLen)
+			}
+			peer = new(Peer)
+
+			for i := range values {
+				switch i {
+				case dumpPeerPublicKeyIndex:
+					peer.PublicKey = []byte(values[i])
+				case dumpPeerPresharedKeyIndex:
+					if values[i] == dumpNone {
+						continue
+					}
+					peer.PresharedKey = []byte(values[i])
+				case dumpPeerEndpointIndex:
+					if values[i] == dumpNone {
+						continue
+					}
+					err = peer.parseEndpoint(values[i])
+					if err != nil {
+						return nil, fmt.Errorf("invalid peer line %d: error parsing endpoint: %w", line, err)
+					}
+				case dumpPeerAllowedIPsIndex:
+					if values[i] == dumpNone {
+						continue
+					}
+					err = peer.parseAllowedIPs(values[i])
+					if err != nil {
+						return nil, fmt.Errorf("invalid peer line %d: error parsing allowed-ips: %w", line, err)
+					}
+				case dumpPeerLatestHandshakeIndex:
+					if values[i] == "0" {
+						// Use go zero value, not unix 0 timestamp.
+						peer.LatestHandshake = time.Time{}
+						continue
+					}
+					sec, err = strconv.ParseInt(values[i], 10, 64)
+					if err != nil {
+						return nil, fmt.Errorf("invalid peer line %d: error parsing latest-handshake: %w", line, err)
+					}
+					peer.LatestHandshake = time.Unix(sec, 0)
+				case dumpPeerPersistentKeepaliveIndex:
+					if values[i] == dumpOff {
+						continue
+					}
+					pka, err = strconv.Atoi(values[i])
+					if err != nil {
+						return nil, fmt.Errorf("invalid peer line %d: error parsing persistent-keepalive: %w", line, err)
+					}
+					peer.PersistentKeepalive = pka
+				}
+			}
+			c.Peers = append(c.Peers, peer)
+			peer = nil
+		}
+		line++
+	}
+	return &c, nil
+}

+ 46 - 0
pkg/wireguard/conf_test.go

@@ -17,6 +17,8 @@ package wireguard
 import (
 import (
 	"net"
 	"net"
 	"testing"
 	"testing"
+
+	"github.com/kylelemons/godebug/pretty"
 )
 )
 
 
 func TestCompareConf(t *testing.T) {
 func TestCompareConf(t *testing.T) {
@@ -308,3 +310,47 @@ func TestCompareEndpoint(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+
+func TestCompareDumpConf(t *testing.T) {
+	for _, tc := range []struct {
+		name string
+		d    []byte
+		c    []byte
+	}{
+		{
+			name: "empty",
+			d:    []byte{},
+			c:    []byte{},
+		},
+		{
+			name: "redacted copy from wg output",
+			d: []byte(`private	B7qk8EMlob0nfado0ABM6HulUV607r4yqtBKjhap7S4=	51820	off
+key1	(none)	10.254.1.1:51820	100.64.1.0/24,192.168.0.125/32,10.4.0.1/32	1619012801	67048	34952	10
+key2	(none)	10.254.2.1:51820	100.64.4.0/24,10.69.76.55/32,100.64.3.0/24,10.66.25.131/32,10.4.0.2/32	1619013058	1134456	10077852	10`),
+			c: []byte(`[Interface]
+		ListenPort = 51820
+		PrivateKey = private
+
+		[Peer]
+		PublicKey = key1
+		AllowedIPs = 100.64.1.0/24, 192.168.0.125/32, 10.4.0.1/32
+		Endpoint = 10.254.1.1:51820
+		PersistentKeepalive = 10
+
+		[Peer]
+		PublicKey = key2
+		AllowedIPs = 100.64.4.0/24, 10.69.76.55/32, 100.64.3.0/24, 10.66.25.131/32, 10.4.0.2/32
+		Endpoint = 10.254.2.1:51820
+		PersistentKeepalive = 10`),
+		},
+	} {
+
+		dumpConf, _ := ParseDump(tc.d)
+		conf := Parse(tc.c)
+		// Equal will ignore runtime fields and only compare configuration fields.
+		if !dumpConf.Equal(conf) {
+			diff := pretty.Compare(dumpConf, conf)
+			t.Errorf("test case %q: got diff: %v", tc.name, diff)
+		}
+	}
+}

+ 12 - 0
pkg/wireguard/wireguard.go

@@ -119,3 +119,15 @@ func ShowConf(iface string) ([]byte, error) {
 	}
 	}
 	return stdout.Bytes(), nil
 	return stdout.Bytes(), nil
 }
 }
+
+// ShowDump gets the WireGuard configuration and runtime information for the given interface.
+func ShowDump(iface string) ([]byte, error) {
+	cmd := exec.Command("wg", "show", iface, "dump")
+	var stderr, stdout bytes.Buffer
+	cmd.Stderr = &stderr
+	cmd.Stdout = &stdout
+	if err := cmd.Run(); err != nil {
+		return nil, fmt.Errorf("failed to read the WireGuard dump output: %s", stderr.String())
+	}
+	return stdout.Bytes(), nil
+}