configure_linux.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. //go:build linux
  2. // +build linux
  3. package wglinux
  4. import (
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "unsafe"
  9. "github.com/mdlayher/netlink"
  10. "github.com/mdlayher/netlink/nlenc"
  11. "golang.org/x/sys/unix"
  12. "golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
  13. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  14. )
  15. // configAttrs creates the required encoded netlink attributes to configure
  16. // the device specified by name using the non-nil fields in cfg.
  17. func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) {
  18. ae := netlink.NewAttributeEncoder()
  19. ae.String(wgh.DeviceAIfname, name)
  20. if cfg.PrivateKey != nil {
  21. ae.Bytes(wgh.DeviceAPrivateKey, (*cfg.PrivateKey)[:])
  22. }
  23. if cfg.ListenPort != nil {
  24. ae.Uint16(wgh.DeviceAListenPort, uint16(*cfg.ListenPort))
  25. }
  26. if cfg.FirewallMark != nil {
  27. ae.Uint32(wgh.DeviceAFwmark, uint32(*cfg.FirewallMark))
  28. }
  29. if cfg.ReplacePeers {
  30. ae.Uint32(wgh.DeviceAFlags, wgh.DeviceFReplacePeers)
  31. }
  32. // Only apply peer attributes if necessary.
  33. if len(cfg.Peers) > 0 {
  34. ae.Nested(wgh.DeviceAPeers, func(nae *netlink.AttributeEncoder) error {
  35. // Netlink arrays use type as an array index.
  36. for i, p := range cfg.Peers {
  37. nae.Nested(uint16(i), encodePeer(p))
  38. }
  39. return nil
  40. })
  41. }
  42. return ae.Encode()
  43. }
  // Thresholds used when deciding whether, and how, to split a large
// configuration into several smaller netlink messages.
const (
	// ipBatchChunk is a tunable allowed IP batch limit per peer.
	//
	// The space a given peer will occupy isn't necessarily known ahead of
	// time, so a conservatively small value is used. This constant is
	// referenced both in this package and in its tests; keep that in mind
	// when changing it.
	ipBatchChunk = 256

	// peerBatchChunk is the number of peers a configuration may contain
	// before it starts being split into chunks.
	peerBatchChunk = 32
)
  53. // shouldBatch determines if a configuration is sufficiently complex that it
  54. // should be split into batches.
  55. func shouldBatch(cfg wgtypes.Config) bool {
  56. if len(cfg.Peers) > peerBatchChunk {
  57. return true
  58. }
  59. var ips int
  60. for _, p := range cfg.Peers {
  61. ips += len(p.AllowedIPs)
  62. }
  63. return ips > ipBatchChunk
  64. }
  65. // buildBatches produces a batch of configs from a single config, if needed.
  66. func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
  67. // Is this a small configuration; no need to batch?
  68. if !shouldBatch(cfg) {
  69. return []wgtypes.Config{cfg}
  70. }
  71. // Use most fields of cfg for our "base" configuration, and only differ
  72. // peers in each batch.
  73. base := cfg
  74. base.Peers = nil
  75. // Track the known peers so that peer IPs are not replaced if a single
  76. // peer has its allowed IPs split into multiple batches.
  77. knownPeers := make(map[wgtypes.Key]struct{})
  78. batches := make([]wgtypes.Config, 0)
  79. for _, p := range cfg.Peers {
  80. batch := base
  81. // Iterate until no more allowed IPs.
  82. var done bool
  83. for !done {
  84. var tmp []net.IPNet
  85. if len(p.AllowedIPs) < ipBatchChunk {
  86. // IPs all fit within a batch; we are done.
  87. tmp = make([]net.IPNet, len(p.AllowedIPs))
  88. copy(tmp, p.AllowedIPs)
  89. done = true
  90. } else {
  91. // IPs are larger than a single batch, copy a batch out and
  92. // advance the cursor.
  93. tmp = make([]net.IPNet, ipBatchChunk)
  94. copy(tmp, p.AllowedIPs[:ipBatchChunk])
  95. p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
  96. if len(p.AllowedIPs) == 0 {
  97. // IPs ended on a batch boundary; no more IPs left so end
  98. // iteration after this loop.
  99. done = true
  100. }
  101. }
  102. pcfg := wgtypes.PeerConfig{
  103. // PublicKey denotes the peer and must be present.
  104. PublicKey: p.PublicKey,
  105. // Apply the update only flag to every chunk to ensure
  106. // consistency between batches when the kernel module processes
  107. // them.
  108. UpdateOnly: p.UpdateOnly,
  109. // It'd be a bit weird to have a remove peer message with many
  110. // IPs, but just in case, add this to every peer's message.
  111. Remove: p.Remove,
  112. // The IPs for this chunk.
  113. AllowedIPs: tmp,
  114. }
  115. // Only pass certain fields on the first occurrence of a peer, so
  116. // that subsequent IPs won't be wiped out and space isn't wasted.
  117. if _, ok := knownPeers[p.PublicKey]; !ok {
  118. knownPeers[p.PublicKey] = struct{}{}
  119. pcfg.PresharedKey = p.PresharedKey
  120. pcfg.Endpoint = p.Endpoint
  121. pcfg.PersistentKeepaliveInterval = p.PersistentKeepaliveInterval
  122. // Important: do not move or appending peers won't work.
  123. pcfg.ReplaceAllowedIPs = p.ReplaceAllowedIPs
  124. }
  125. // Add a peer configuration to this batch and keep going.
  126. batch.Peers = []wgtypes.PeerConfig{pcfg}
  127. batches = append(batches, batch)
  128. }
  129. }
  130. // Do not allow peer replacement beyond the first message in a batch,
  131. // so we don't overwrite our previous batch work.
  132. for i := range batches {
  133. if i > 0 {
  134. batches[i].ReplacePeers = false
  135. }
  136. }
  137. return batches
  138. }
  139. // encodePeer returns a function to encode PeerConfig nested attributes.
  140. func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
  141. return func(ae *netlink.AttributeEncoder) error {
  142. ae.Bytes(wgh.PeerAPublicKey, p.PublicKey[:])
  143. // Flags are stored in a single attribute.
  144. var flags uint32
  145. if p.Remove {
  146. flags |= wgh.PeerFRemoveMe
  147. }
  148. if p.ReplaceAllowedIPs {
  149. flags |= wgh.PeerFReplaceAllowedips
  150. }
  151. if p.UpdateOnly {
  152. flags |= wgh.PeerFUpdateOnly
  153. }
  154. if flags != 0 {
  155. ae.Uint32(wgh.PeerAFlags, flags)
  156. }
  157. if p.PresharedKey != nil {
  158. ae.Bytes(wgh.PeerAPresharedKey, (*p.PresharedKey)[:])
  159. }
  160. if p.Endpoint != nil {
  161. ae.Do(wgh.PeerAEndpoint, encodeSockaddr(*p.Endpoint))
  162. }
  163. if p.PersistentKeepaliveInterval != nil {
  164. ae.Uint16(wgh.PeerAPersistentKeepaliveInterval, uint16(p.PersistentKeepaliveInterval.Seconds()))
  165. }
  166. // Only apply allowed IPs if necessary.
  167. if len(p.AllowedIPs) > 0 {
  168. ae.Nested(wgh.PeerAAllowedips, encodeAllowedIPs(p.AllowedIPs))
  169. }
  170. return nil
  171. }
  172. }
  173. // encodeSockaddr returns a function which encodes a net.UDPAddr as raw
  174. // sockaddr_in or sockaddr_in6 bytes.
  175. func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
  176. return func() ([]byte, error) {
  177. if !isValidIP(endpoint.IP) {
  178. return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
  179. }
  180. // Is this an IPv6 address?
  181. if isIPv6(endpoint.IP) {
  182. var addr [16]byte
  183. copy(addr[:], endpoint.IP.To16())
  184. sa := unix.RawSockaddrInet6{
  185. Family: unix.AF_INET6,
  186. Port: sockaddrPort(endpoint.Port),
  187. Addr: addr,
  188. }
  189. return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
  190. }
  191. // IPv4 address handling.
  192. var addr [4]byte
  193. copy(addr[:], endpoint.IP.To4())
  194. sa := unix.RawSockaddrInet4{
  195. Family: unix.AF_INET,
  196. Port: sockaddrPort(endpoint.Port),
  197. Addr: addr,
  198. }
  199. return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
  200. }
  201. }
  202. // encodeAllowedIPs returns a function to encode allowed IP nested attributes.
  203. func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
  204. return func(ae *netlink.AttributeEncoder) error {
  205. for i, ipn := range ipns {
  206. if !isValidIP(ipn.IP) {
  207. return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
  208. }
  209. family := uint16(unix.AF_INET6)
  210. if !isIPv6(ipn.IP) {
  211. // Make sure address is 4 bytes if IPv4.
  212. family = unix.AF_INET
  213. ipn.IP = ipn.IP.To4()
  214. }
  215. // Netlink arrays use type as an array index.
  216. ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
  217. nae.Uint16(wgh.AllowedipAFamily, family)
  218. nae.Bytes(wgh.AllowedipAIpaddr, ipn.IP)
  219. ones, _ := ipn.Mask.Size()
  220. nae.Uint8(wgh.AllowedipACidrMask, uint8(ones))
  221. return nil
  222. })
  223. }
  224. return nil
  225. }
  226. }
  // isValidIP reports whether ip has the length of a 4-byte IPv4 or a 16-byte
// IPv6 address, i.e. whether net.IP.To16 is able to interpret it.
func isValidIP(ip net.IP) bool {
	switch len(ip) {
	case net.IPv4len, net.IPv6len:
		return true
	default:
		return false
	}
}

// isIPv6 reports whether ip is a valid address that has no 4-byte IPv4
// form. IPv4-mapped IPv6 addresses therefore report false.
func isIPv6(ip net.IP) bool {
	if !isValidIP(ip) {
		return false
	}
	return ip.To4() == nil
}
  235. // sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
  236. // structures to the kernel.
  237. func sockaddrPort(port int) uint16 {
  238. return binary.BigEndian.Uint16(nlenc.Uint16Bytes(uint16(port)))
  239. }