conf.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. // Copyright 2019 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package wireguard
  15. import (
  16. "bytes"
  17. "errors"
  18. "fmt"
  19. "net"
  20. "sort"
  21. "strconv"
  22. "time"
  23. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  24. "k8s.io/apimachinery/pkg/util/validation"
  25. )
  26. type section string
  27. type key string
  28. const (
  29. interfaceSection section = "Interface"
  30. peerSection section = "Peer"
  31. listenPortKey key = "ListenPort"
  32. allowedIPsKey key = "AllowedIPs"
  33. endpointKey key = "Endpoint"
  34. persistentKeepaliveKey key = "PersistentKeepalive"
  35. presharedKeyKey key = "PresharedKey"
  36. privateKeyKey key = "PrivateKey"
  37. publicKeyKey key = "PublicKey"
  38. )
// Conf represents a WireGuard configuration file.
// It embeds wgtypes.Config for the interface-level settings (for example the
// private key and listen port rendered by Bytes) and shadows its Peers field.
type Conf struct {
	wgtypes.Config
	// The Peers field is shadowed because every Peer needs the Endpoint field that contains a DNS endpoint.
	// WGConfig derives the embedded Config's peer list from this field.
	Peers []Peer
}
  45. // WGConfig returns a wgytpes.Config from a Conf.
  46. func (c *Conf) WGConfig() wgtypes.Config {
  47. if c == nil {
  48. // The empty Config will do nothing, when applied.
  49. return wgtypes.Config{}
  50. }
  51. r := c.Config
  52. wgPs := make([]wgtypes.PeerConfig, len(c.Peers))
  53. for i, p := range c.Peers {
  54. wgPs[i] = p.PeerConfig
  55. if p.Endpoint.Resolved() {
  56. // We can ingore the error because we already checked if the Endpoint was resolved in the above line.
  57. wgPs[i].Endpoint, _ = p.Endpoint.UDPAddr(false)
  58. }
  59. wgPs[i].ReplaceAllowedIPs = true
  60. }
  61. r.Peers = wgPs
  62. r.ReplacePeers = true
  63. return r
  64. }
// Endpoint represents a WireGuard endpoint.
// It holds a resolved UDP address, a "host:port" string with a DNS name, or
// both once the DNS name has been resolved by UDPAddr(true).
type Endpoint struct {
	// udpAddr is the resolved address; nil while unresolved.
	udpAddr *net.UDPAddr
	// addr is the "host:port" string of a DNS-name endpoint; empty for
	// endpoints created from an IP.
	addr string
}
  70. // ParseEndpoint returns an Endpoint from a string.
  71. // The input should look like "10.0.0.0:100", "[ff10::10]:100"
  72. // or "example.com:100".
  73. func ParseEndpoint(endpoint string) *Endpoint {
  74. if len(endpoint) == 0 {
  75. return nil
  76. }
  77. hostRaw, portRaw, err := net.SplitHostPort(endpoint)
  78. if err != nil {
  79. return nil
  80. }
  81. port, err := strconv.ParseUint(portRaw, 10, 32)
  82. if err != nil {
  83. return nil
  84. }
  85. if len(validation.IsValidPortNum(int(port))) != 0 {
  86. return nil
  87. }
  88. ip := net.ParseIP(hostRaw)
  89. if ip == nil {
  90. if len(validation.IsDNS1123Subdomain(hostRaw)) == 0 {
  91. return &Endpoint{
  92. addr: endpoint,
  93. }
  94. }
  95. return nil
  96. }
  97. // ResolveUDPAddr will not resolve the endpoint as long as a valid IP and port is given.
  98. // This should be the case here.
  99. u, err := net.ResolveUDPAddr("udp", endpoint)
  100. if err != nil {
  101. return nil
  102. }
  103. u.IP = cutIP(u.IP)
  104. return &Endpoint{
  105. udpAddr: u,
  106. }
  107. }
  108. // NewEndpointFromUDPAddr returns an Endpoint from a net.UDPAddr.
  109. func NewEndpointFromUDPAddr(u *net.UDPAddr) *Endpoint {
  110. if u != nil {
  111. u.IP = cutIP(u.IP)
  112. }
  113. return &Endpoint{
  114. udpAddr: u,
  115. }
  116. }
  117. // NewEndpoint returns an Endpoint from a net.IP and port.
  118. func NewEndpoint(ip net.IP, port int) *Endpoint {
  119. return &Endpoint{
  120. udpAddr: &net.UDPAddr{
  121. IP: cutIP(ip),
  122. Port: port,
  123. },
  124. }
  125. }
  126. // Ready return true, if the Enpoint is ready.
  127. // Ready means that an IP or DN and port exists.
  128. func (e *Endpoint) Ready() bool {
  129. if e == nil {
  130. return false
  131. }
  132. return (e.udpAddr != nil && e.udpAddr.IP != nil && e.udpAddr.Port > 0) || len(e.addr) > 0
  133. }
  134. // Port returns the port of the Endpoint.
  135. func (e *Endpoint) Port() int {
  136. if !e.Ready() {
  137. return 0
  138. }
  139. if e.udpAddr != nil {
  140. return e.udpAddr.Port
  141. }
  142. // We can ignore the errors here bacause the returned port will be "".
  143. // This will result to Port 0 after the conversion to and int.
  144. _, p, _ := net.SplitHostPort(e.addr)
  145. port, _ := strconv.ParseUint(p, 10, 32)
  146. return int(port)
  147. }
  148. // HasDNS returns true if the endpoint has a DN.
  149. func (e *Endpoint) HasDNS() bool {
  150. return e != nil && e.addr != ""
  151. }
  152. // DNS returns the DN of the Endpoint.
  153. func (e *Endpoint) DNS() string {
  154. if e == nil {
  155. return ""
  156. }
  157. _, s, _ := net.SplitHostPort(e.addr)
  158. return s
  159. }
  160. // Resolved returns true, if the DN of the Endpoint was resolved
  161. // or if the Endpoint has a resolved endpoint.
  162. func (e *Endpoint) Resolved() bool {
  163. return e != nil && e.udpAddr != nil
  164. }
  165. // UDPAddr returns the UDPAddr of the Endpoint. If resolve is false,
  166. // UDPAddr() will not try to resolve a DN name, if the Endpoint is not yet resolved.
  167. func (e *Endpoint) UDPAddr(resolve bool) (*net.UDPAddr, error) {
  168. if !e.Ready() {
  169. return nil, errors.New("endpoint is not ready")
  170. }
  171. if e.udpAddr != nil {
  172. // Make a copy of the UDPAddr to protect it from modification outside this package.
  173. h := *e.udpAddr
  174. return &h, nil
  175. }
  176. if !resolve {
  177. return nil, errors.New("endpoint is not resolved")
  178. }
  179. var err error
  180. if e.udpAddr, err = net.ResolveUDPAddr("udp", e.addr); err != nil {
  181. return nil, err
  182. }
  183. // Make a copy of the UDPAddr to protect it from modification outside this package.
  184. h := *e.udpAddr
  185. return &h, nil
  186. }
  187. // IP returns the IP address of the Enpoint or nil.
  188. func (e *Endpoint) IP() net.IP {
  189. if !e.Resolved() {
  190. return nil
  191. }
  192. return e.udpAddr.IP
  193. }
  194. // String will return the endpoint as a string.
  195. // If a DN exists, it will take prcedence over the resolved endpoint.
  196. func (e *Endpoint) String() string {
  197. return e.StringOpt(true)
  198. }
  199. // StringOpt will return the string of the Endpoint.
  200. // If dnsFirst is false, the resolved Endpoint will
  201. // take precedence over the DN.
  202. func (e *Endpoint) StringOpt(dnsFirst bool) string {
  203. if e == nil {
  204. return ""
  205. }
  206. if e.udpAddr != nil && (!dnsFirst || e.addr == "") {
  207. return e.udpAddr.String()
  208. }
  209. return e.addr
  210. }
  211. // Equal will return true, if the Enpoints are equal.
  212. // If dnsFirst is false, the DN will only be compared if
  213. // the IPs are nil.
  214. func (e *Endpoint) Equal(b *Endpoint, dnsFirst bool) bool {
  215. return e.StringOpt(dnsFirst) == b.StringOpt(dnsFirst)
  216. }
// Peer represents a `peer` section of a WireGuard configuration.
// It embeds wgtypes.PeerConfig and shadows its endpoint with an *Endpoint,
// which can also hold a DNS name. WGConfig copies it into the PeerConfig's
// endpoint once it is resolved.
type Peer struct {
	wgtypes.PeerConfig
	Endpoint *Endpoint
}
  222. // DeduplicateIPs eliminates duplicate allowed IPs.
  223. func (p *Peer) DeduplicateIPs() {
  224. var ips []net.IPNet
  225. seen := make(map[string]struct{})
  226. for _, ip := range p.AllowedIPs {
  227. if _, ok := seen[ip.String()]; ok {
  228. continue
  229. }
  230. ips = append(ips, ip)
  231. seen[ip.String()] = struct{}{}
  232. }
  233. p.AllowedIPs = ips
  234. }
  235. // Bytes renders a WireGuard configuration to bytes.
  236. func (c *Conf) Bytes() ([]byte, error) {
  237. if c == nil {
  238. return nil, nil
  239. }
  240. var err error
  241. buf := bytes.NewBuffer(make([]byte, 0, 512))
  242. if c.PrivateKey != nil {
  243. if err = writeSection(buf, interfaceSection); err != nil {
  244. return nil, fmt.Errorf("failed to write interface: %v", err)
  245. }
  246. if err = writePKey(buf, privateKeyKey, c.PrivateKey); err != nil {
  247. return nil, fmt.Errorf("failed to write private key: %v", err)
  248. }
  249. if err = writeValue(buf, listenPortKey, strconv.Itoa(*c.ListenPort)); err != nil {
  250. return nil, fmt.Errorf("failed to write listen port: %v", err)
  251. }
  252. }
  253. for i, p := range c.Peers {
  254. // Add newlines to make the formatting nicer.
  255. if i == 0 && c.PrivateKey != nil || i != 0 {
  256. if err = buf.WriteByte('\n'); err != nil {
  257. return nil, err
  258. }
  259. }
  260. if err = writeSection(buf, peerSection); err != nil {
  261. return nil, fmt.Errorf("failed to write interface: %v", err)
  262. }
  263. if err = writeAllowedIPs(buf, p.AllowedIPs); err != nil {
  264. return nil, fmt.Errorf("failed to write allowed IPs: %v", err)
  265. }
  266. if err = writeEndpoint(buf, p.Endpoint); err != nil {
  267. return nil, fmt.Errorf("failed to write endpoint: %v", err)
  268. }
  269. if p.PersistentKeepaliveInterval == nil {
  270. p.PersistentKeepaliveInterval = new(time.Duration)
  271. }
  272. if err = writeValue(buf, persistentKeepaliveKey, strconv.FormatUint(uint64(*p.PersistentKeepaliveInterval/time.Second), 10)); err != nil {
  273. return nil, fmt.Errorf("failed to write persistent keepalive: %v", err)
  274. }
  275. if err = writePKey(buf, presharedKeyKey, p.PresharedKey); err != nil {
  276. return nil, fmt.Errorf("failed to write preshared key: %v", err)
  277. }
  278. if err = writePKey(buf, publicKeyKey, &p.PublicKey); err != nil {
  279. return nil, fmt.Errorf("failed to write public key: %v", err)
  280. }
  281. }
  282. return buf.Bytes(), nil
  283. }
  284. // Equal returns true if the Conf and wgtypes.Device are equal.
  285. func (c *Conf) Equal(d *wgtypes.Device) (bool, string) {
  286. if c == nil || d == nil {
  287. return c == nil && d == nil, "nil values"
  288. }
  289. if c.ListenPort == nil || *c.ListenPort != d.ListenPort {
  290. return false, fmt.Sprintf("port: old=%q, new=\"%v\"", d.ListenPort, c.ListenPort)
  291. }
  292. if c.PrivateKey == nil || *c.PrivateKey != d.PrivateKey {
  293. return false, fmt.Sprintf("private key: old=\"%s...\", new=\"%s\"", d.PrivateKey.String()[0:5], c.PrivateKey.String()[0:5])
  294. }
  295. if len(c.Peers) != len(d.Peers) {
  296. return false, fmt.Sprintf("number of peers: old=%d, new=%d", len(d.Peers), len(c.Peers))
  297. }
  298. sortPeerConfigs(d.Peers)
  299. sortPeers(c.Peers)
  300. for i := range c.Peers {
  301. if len(c.Peers[i].AllowedIPs) != len(d.Peers[i].AllowedIPs) {
  302. return false, fmt.Sprintf("Peer %d allowed IP length: old=%d, new=%d", i, len(d.Peers[i].AllowedIPs), len(c.Peers[i].AllowedIPs))
  303. }
  304. sortCIDRs(c.Peers[i].AllowedIPs)
  305. sortCIDRs(d.Peers[i].AllowedIPs)
  306. for j := range c.Peers[i].AllowedIPs {
  307. if c.Peers[i].AllowedIPs[j].String() != d.Peers[i].AllowedIPs[j].String() {
  308. return false, fmt.Sprintf("Peer %d allowed IP: old=%q, new=%q", i, d.Peers[i].AllowedIPs[j].String(), c.Peers[i].AllowedIPs[j].String())
  309. }
  310. }
  311. if c.Peers[i].Endpoint == nil || d.Peers[i].Endpoint == nil {
  312. return c.Peers[i].Endpoint == nil && d.Peers[i].Endpoint == nil, "peer endpoints: nil value"
  313. }
  314. if c.Peers[i].Endpoint.StringOpt(false) != d.Peers[i].Endpoint.String() {
  315. return false, fmt.Sprintf("Peer %d endpoint: old=%q, new=%q", i, d.Peers[i].Endpoint.String(), c.Peers[i].Endpoint.StringOpt(false))
  316. }
  317. pki := time.Duration(0)
  318. if p := c.Peers[i].PersistentKeepaliveInterval; p != nil {
  319. pki = *p
  320. }
  321. psk := wgtypes.Key{}
  322. if p := c.Peers[i].PresharedKey; p != nil {
  323. psk = *p
  324. }
  325. if pki != d.Peers[i].PersistentKeepaliveInterval || psk != d.Peers[i].PresharedKey || c.Peers[i].PublicKey != d.Peers[i].PublicKey {
  326. return false, "persistent keepalive or pershared key"
  327. }
  328. }
  329. return true, ""
  330. }
  331. func sortPeerConfigs(peers []wgtypes.Peer) {
  332. sort.Slice(peers, func(i, j int) bool {
  333. return peers[i].PublicKey.String() < peers[j].PublicKey.String()
  334. })
  335. }
  336. func sortPeers(peers []Peer) {
  337. sort.Slice(peers, func(i, j int) bool {
  338. return peers[i].PublicKey.String() < peers[j].PublicKey.String()
  339. })
  340. }
  341. func sortCIDRs(cidrs []net.IPNet) {
  342. sort.Slice(cidrs, func(i, j int) bool {
  343. return cidrs[i].String() < cidrs[j].String()
  344. })
  345. }
  346. func cutIP(ip net.IP) net.IP {
  347. if i4 := ip.To4(); i4 != nil {
  348. return i4
  349. }
  350. return ip.To16()
  351. }
  352. func writeAllowedIPs(buf *bytes.Buffer, ais []net.IPNet) error {
  353. if len(ais) == 0 {
  354. return nil
  355. }
  356. var err error
  357. if err = writeKey(buf, allowedIPsKey); err != nil {
  358. return err
  359. }
  360. for i := range ais {
  361. if i != 0 {
  362. if _, err = buf.WriteString(", "); err != nil {
  363. return err
  364. }
  365. }
  366. if _, err = buf.WriteString(ais[i].String()); err != nil {
  367. return err
  368. }
  369. }
  370. return buf.WriteByte('\n')
  371. }
  372. func writePKey(buf *bytes.Buffer, k key, b *wgtypes.Key) error {
  373. // Print nothing if the public key was never initialized.
  374. if b == nil || (wgtypes.Key{}) == *b {
  375. return nil
  376. }
  377. var err error
  378. if err = writeKey(buf, k); err != nil {
  379. return err
  380. }
  381. if _, err = buf.Write([]byte(b.String())); err != nil {
  382. return err
  383. }
  384. return buf.WriteByte('\n')
  385. }
  386. func writeValue(buf *bytes.Buffer, k key, v string) error {
  387. var err error
  388. if err = writeKey(buf, k); err != nil {
  389. return err
  390. }
  391. if _, err = buf.WriteString(v); err != nil {
  392. return err
  393. }
  394. return buf.WriteByte('\n')
  395. }
  396. func writeEndpoint(buf *bytes.Buffer, e *Endpoint) error {
  397. str := e.String()
  398. if str == "" {
  399. return nil
  400. }
  401. var err error
  402. if err = writeKey(buf, endpointKey); err != nil {
  403. return err
  404. }
  405. if _, err = buf.WriteString(str); err != nil {
  406. return err
  407. }
  408. return buf.WriteByte('\n')
  409. }
  410. func writeSection(buf *bytes.Buffer, s section) error {
  411. var err error
  412. if err = buf.WriteByte('['); err != nil {
  413. return err
  414. }
  415. if _, err = buf.WriteString(string(s)); err != nil {
  416. return err
  417. }
  418. if err = buf.WriteByte(']'); err != nil {
  419. return err
  420. }
  421. return buf.WriteByte('\n')
  422. }
  423. func writeKey(buf *bytes.Buffer, k key) error {
  424. var err error
  425. if _, err = buf.WriteString(string(k)); err != nil {
  426. return err
  427. }
  428. _, err = buf.WriteString(" = ")
  429. return err
  430. }