| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421 |
- // Copyright 2019 the Kilo authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package wireguard
- import (
- "bufio"
- "bytes"
- "fmt"
- "net"
- "sort"
- "strconv"
- "strings"
- "k8s.io/apimachinery/pkg/util/validation"
- )
type (
	// section is the name of a bracketed section header, e.g. "Peer".
	section string
	// key is the name of a key inside a section, e.g. "Endpoint".
	key string
)

// separator splits a configuration line into its key and its value.
const separator = "="

// Section headers recognized by Parse and emitted by Bytes.
const (
	interfaceSection section = "Interface"
	peerSection      section = "Peer"
)

// Keys recognized by Parse and emitted by Bytes.
const (
	listenPortKey          key = "ListenPort"
	allowedIPsKey          key = "AllowedIPs"
	endpointKey            key = "Endpoint"
	persistentKeepaliveKey key = "PersistentKeepalive"
	privateKeyKey          key = "PrivateKey"
	publicKeyKey           key = "PublicKey"
)
- // Conf represents a WireGuard configuration file.
- type Conf struct {
- Interface *Interface
- Peers []*Peer
- }
- // Interface represents the `interface` section of a WireGuard configuration.
- type Interface struct {
- ListenPort uint32
- PrivateKey []byte
- }
- // Peer represents a `peer` section of a WireGuard configuration.
- type Peer struct {
- AllowedIPs []*net.IPNet
- Endpoint *Endpoint
- PersistentKeepalive int
- PublicKey []byte
- }
- // DeduplicateIPs eliminates duplicate allowed IPs.
- func (p *Peer) DeduplicateIPs() {
- var ips []*net.IPNet
- seen := make(map[string]struct{})
- for _, ip := range p.AllowedIPs {
- if _, ok := seen[ip.String()]; ok {
- continue
- }
- ips = append(ips, ip)
- seen[ip.String()] = struct{}{}
- }
- p.AllowedIPs = ips
- }
// Endpoint represents an `endpoint` key of a `peer` section.
type Endpoint struct {
	DNSOrIP
	Port uint32
}

// String prints the string representation of the endpoint in host:port form.
// IPv6 addresses are wrapped in square brackets so the port separator is
// unambiguous. A nil *Endpoint renders as the empty string.
func (e *Endpoint) String() string {
	if e == nil {
		return ""
	}
	dnsOrIP := e.DNSOrIP.String()
	// Bracket only addresses that print in IPv6 notation. A 16-byte slice may
	// hold an IPv4 address (net.ParseIP returns that form), which IP.String
	// prints as a dotted quad and which must therefore not be bracketed.
	if len(e.IP) == net.IPv6len && e.IP.To4() == nil {
		dnsOrIP = "[" + dnsOrIP + "]"
	}
	return dnsOrIP + ":" + strconv.FormatUint(uint64(e.Port), 10)
}

// DNSOrIP represents either a DNS name or an IP address.
// IPs, as they are more specific, are preferred.
type DNSOrIP struct {
	DNS string
	IP  net.IP
}

// String prints the string representation of the struct: the IP when it is
// set, otherwise the DNS name.
func (d DNSOrIP) String() string {
	if d.IP != nil {
		return d.IP.String()
	}
	return d.DNS
}
- // Parse parses a given WireGuard configuration file and produces a Conf struct.
- func Parse(buf []byte) *Conf {
- var (
- active section
- ai *net.IPNet
- kv []string
- c Conf
- err error
- iface *Interface
- i int
- ip, ip4 net.IP
- k key
- line, v string
- peer *Peer
- port uint64
- )
- s := bufio.NewScanner(bytes.NewBuffer(buf))
- for s.Scan() {
- line = strings.TrimSpace(s.Text())
- // Skip comments.
- if strings.HasPrefix(line, "#") {
- continue
- }
- // Line is a section title.
- if strings.HasPrefix(line, "[") {
- if peer != nil {
- c.Peers = append(c.Peers, peer)
- peer = nil
- }
- if iface != nil {
- c.Interface = iface
- iface = nil
- }
- active = section(strings.TrimSpace(strings.Trim(line, "[]")))
- switch active {
- case interfaceSection:
- iface = new(Interface)
- case peerSection:
- peer = new(Peer)
- }
- continue
- }
- kv = strings.SplitN(line, separator, 2)
- if len(kv) != 2 {
- continue
- }
- k = key(strings.TrimSpace(kv[0]))
- v = strings.TrimSpace(kv[1])
- switch active {
- case interfaceSection:
- switch k {
- case listenPortKey:
- port, err = strconv.ParseUint(v, 10, 32)
- if err != nil {
- continue
- }
- iface.ListenPort = uint32(port)
- case privateKeyKey:
- iface.PrivateKey = []byte(v)
- }
- case peerSection:
- switch k {
- case allowedIPsKey:
- // Reuse string slice.
- kv = strings.Split(v, ",")
- for i = range kv {
- ip, ai, err = net.ParseCIDR(strings.TrimSpace(kv[i]))
- if err != nil {
- continue
- }
- if ip4 = ip.To4(); ip4 != nil {
- ip = ip4
- } else {
- ip = ip.To16()
- }
- ai.IP = ip
- peer.AllowedIPs = append(peer.AllowedIPs, ai)
- }
- case endpointKey:
- // Reuse string slice.
- kv = strings.Split(v, ":")
- if len(kv) < 2 {
- continue
- }
- port, err = strconv.ParseUint(kv[len(kv)-1], 10, 32)
- if err != nil {
- continue
- }
- d := DNSOrIP{}
- ip = net.ParseIP(strings.Trim(strings.Join(kv[:len(kv)-1], ":"), "[]"))
- if ip == nil {
- if len(validation.IsDNS1123Subdomain(kv[0])) != 0 {
- continue
- }
- d.DNS = kv[0]
- } else {
- if ip4 = ip.To4(); ip4 != nil {
- d.IP = ip4
- } else {
- d.IP = ip.To16()
- }
- }
- peer.Endpoint = &Endpoint{
- DNSOrIP: d,
- Port: uint32(port),
- }
- case persistentKeepaliveKey:
- i, err = strconv.Atoi(v)
- if err != nil {
- continue
- }
- peer.PersistentKeepalive = i
- case publicKeyKey:
- peer.PublicKey = []byte(v)
- }
- }
- }
- if peer != nil {
- c.Peers = append(c.Peers, peer)
- }
- if iface != nil {
- c.Interface = iface
- }
- return &c
- }
- // Bytes renders a WireGuard configuration to bytes.
- func (c *Conf) Bytes() ([]byte, error) {
- var err error
- buf := bytes.NewBuffer(make([]byte, 0, 512))
- if c.Interface != nil {
- if err = writeSection(buf, interfaceSection); err != nil {
- return nil, fmt.Errorf("failed to write interface: %v", err)
- }
- if err = writePKey(buf, privateKeyKey, c.Interface.PrivateKey); err != nil {
- return nil, fmt.Errorf("failed to write private key: %v", err)
- }
- if err = writeValue(buf, listenPortKey, strconv.FormatUint(uint64(c.Interface.ListenPort), 10)); err != nil {
- return nil, fmt.Errorf("failed to write listen port: %v", err)
- }
- }
- for i, p := range c.Peers {
- // Add newlines to make the formatting nicer.
- if i == 0 && c.Interface != nil || i != 0 {
- if err = buf.WriteByte('\n'); err != nil {
- return nil, err
- }
- }
- if err = writeSection(buf, peerSection); err != nil {
- return nil, fmt.Errorf("failed to write interface: %v", err)
- }
- if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
- return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
- }
- if err = writeEndpoint(buf, p.Endpoint); err != nil {
- return nil, fmt.Errorf("failed to write endpoint: %v", err)
- }
- if err = writeValue(buf, persistentKeepaliveKey, strconv.Itoa(p.PersistentKeepalive)); err != nil {
- return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
- }
- if err = writePKey(buf, publicKeyKey, p.PublicKey); err != nil {
- return nil, fmt.Errorf("failed to write public key: %v", err)
- }
- }
- return buf.Bytes(), nil
- }
- // Equal checks if two WireGuard configurations are equivalent.
- func (c *Conf) Equal(b *Conf) bool {
- if (c.Interface == nil) != (b.Interface == nil) {
- return false
- }
- if c.Interface != nil {
- if c.Interface.ListenPort != b.Interface.ListenPort || !bytes.Equal(c.Interface.PrivateKey, b.Interface.PrivateKey) {
- return false
- }
- }
- if len(c.Peers) != len(b.Peers) {
- return false
- }
- sortPeers(c.Peers)
- sortPeers(b.Peers)
- for i := range c.Peers {
- if len(c.Peers[i].AllowedIPs) != len(b.Peers[i].AllowedIPs) {
- return false
- }
- sortCIDRs(c.Peers[i].AllowedIPs)
- sortCIDRs(b.Peers[i].AllowedIPs)
- for j := range c.Peers[i].AllowedIPs {
- if c.Peers[i].AllowedIPs[j].String() != b.Peers[i].AllowedIPs[j].String() {
- return false
- }
- }
- if (c.Peers[i].Endpoint == nil) != (b.Peers[i].Endpoint == nil) {
- return false
- }
- if c.Peers[i].Endpoint != nil {
- if c.Peers[i].Endpoint.Port != b.Peers[i].Endpoint.Port {
- return false
- }
- // IPs take priority, so check them first.
- if !c.Peers[i].Endpoint.IP.Equal(b.Peers[i].Endpoint.IP) {
- return false
- }
- // Only check the DNS name if the IP is empty.
- if c.Peers[i].Endpoint.IP == nil && c.Peers[i].Endpoint.DNS != b.Peers[i].Endpoint.DNS {
- return false
- }
- }
- if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {
- return false
- }
- }
- return true
- }
- func sortPeers(peers []*Peer) {
- sort.Slice(peers, func(i, j int) bool {
- if bytes.Compare(peers[i].PublicKey, peers[j].PublicKey) < 0 {
- return true
- }
- return false
- })
- }
// sortCIDRs orders networks in place by their string representation.
func sortCIDRs(cidrs []*net.IPNet) {
	sort.Slice(cidrs, func(a, b int) bool {
		left, right := cidrs[a].String(), cidrs[b].String()
		return left < right
	})
}
- func writeAllowedIPs(buf *bytes.Buffer, ais []*net.IPNet) error {
- if len(ais) == 0 {
- return nil
- }
- var err error
- if err = writeKey(buf, allowedIPsKey); err != nil {
- return err
- }
- for i := range ais {
- if i != 0 {
- if _, err = buf.WriteString(", "); err != nil {
- return err
- }
- }
- if _, err = buf.WriteString(ais[i].String()); err != nil {
- return err
- }
- }
- return buf.WriteByte('\n')
- }
- func writePKey(buf *bytes.Buffer, k key, b []byte) error {
- if len(b) == 0 {
- return nil
- }
- var err error
- if err = writeKey(buf, k); err != nil {
- return err
- }
- if _, err = buf.Write(b); err != nil {
- return err
- }
- return buf.WriteByte('\n')
- }
- func writeValue(buf *bytes.Buffer, k key, v string) error {
- var err error
- if err = writeKey(buf, k); err != nil {
- return err
- }
- if _, err = buf.WriteString(v); err != nil {
- return err
- }
- return buf.WriteByte('\n')
- }
- func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
- if e == nil {
- return nil
- }
- var err error
- if err = writeKey(buf, endpointKey); err != nil {
- return err
- }
- if _, err = buf.WriteString(e.String()); err != nil {
- return err
- }
- return buf.WriteByte('\n')
- }
- func writeSection(buf *bytes.Buffer, s section) error {
- var err error
- if err = buf.WriteByte('['); err != nil {
- return err
- }
- if _, err = buf.WriteString(string(s)); err != nil {
- return err
- }
- if err = buf.WriteByte(']'); err != nil {
- return err
- }
- return buf.WriteByte('\n')
- }
- func writeKey(buf *bytes.Buffer, k key) error {
- var err error
- if _, err = buf.WriteString(string(k)); err != nil {
- return err
- }
- _, err = buf.WriteString(" = ")
- return err
- }
|