parse_linux.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. //go:build linux
  2. // +build linux
  3. package wglinux
  4. import (
  5. "fmt"
  6. "net"
  7. "time"
  8. "unsafe"
  9. "github.com/mdlayher/genetlink"
  10. "github.com/mdlayher/netlink"
  11. "golang.org/x/sys/unix"
  12. "golang.zx2c4.com/wireguard/wgctrl/internal/wglinux/internal/wgh"
  13. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  14. )
  15. // parseDevice parses a Device from a slice of generic netlink messages,
  16. // automatically merging peer lists from subsequent messages into the Device
  17. // from the first message.
  18. func parseDevice(msgs []genetlink.Message) (*wgtypes.Device, error) {
  19. var first wgtypes.Device
  20. knownPeers := make(map[wgtypes.Key]int)
  21. for i, m := range msgs {
  22. d, err := parseDeviceLoop(m)
  23. if err != nil {
  24. return nil, err
  25. }
  26. if i == 0 {
  27. // First message contains our target device.
  28. first = *d
  29. // Gather the known peers so that we can merge
  30. // them later if needed
  31. for i := range first.Peers {
  32. knownPeers[first.Peers[i].PublicKey] = i
  33. }
  34. continue
  35. }
  36. // Any subsequent messages have their peer contents merged into the
  37. // first "target" message.
  38. mergeDevices(&first, d, knownPeers)
  39. }
  40. return &first, nil
  41. }
  42. // parseDeviceLoop parses a Device from a single generic netlink message.
  43. func parseDeviceLoop(m genetlink.Message) (*wgtypes.Device, error) {
  44. ad, err := netlink.NewAttributeDecoder(m.Data)
  45. if err != nil {
  46. return nil, err
  47. }
  48. d := wgtypes.Device{Type: wgtypes.LinuxKernel}
  49. for ad.Next() {
  50. switch ad.Type() {
  51. case wgh.DeviceAIfindex:
  52. // Ignored; interface index isn't exposed at all in the userspace
  53. // configuration protocol, and name is more friendly anyway.
  54. case wgh.DeviceAIfname:
  55. d.Name = ad.String()
  56. case wgh.DeviceAPrivateKey:
  57. ad.Do(parseKey(&d.PrivateKey))
  58. case wgh.DeviceAPublicKey:
  59. ad.Do(parseKey(&d.PublicKey))
  60. case wgh.DeviceAListenPort:
  61. d.ListenPort = int(ad.Uint16())
  62. case wgh.DeviceAFwmark:
  63. d.FirewallMark = int(ad.Uint32())
  64. case wgh.DeviceAPeers:
  65. // Netlink array of peers.
  66. //
  67. // Errors while parsing are propagated up to top-level ad.Err check.
  68. ad.Nested(func(nad *netlink.AttributeDecoder) error {
  69. // Initialize to the number of peers in this decoder and begin
  70. // handling nested Peer attributes.
  71. d.Peers = make([]wgtypes.Peer, 0, nad.Len())
  72. for nad.Next() {
  73. nad.Nested(func(nnad *netlink.AttributeDecoder) error {
  74. d.Peers = append(d.Peers, parsePeer(nnad))
  75. return nil
  76. })
  77. }
  78. return nil
  79. })
  80. }
  81. }
  82. if err := ad.Err(); err != nil {
  83. return nil, err
  84. }
  85. return &d, nil
  86. }
  87. // parseAllowedIPs parses a wgtypes.Peer from a netlink attribute payload.
  88. func parsePeer(ad *netlink.AttributeDecoder) wgtypes.Peer {
  89. var p wgtypes.Peer
  90. for ad.Next() {
  91. switch ad.Type() {
  92. case wgh.PeerAPublicKey:
  93. ad.Do(parseKey(&p.PublicKey))
  94. case wgh.PeerAPresharedKey:
  95. ad.Do(parseKey(&p.PresharedKey))
  96. case wgh.PeerAEndpoint:
  97. p.Endpoint = &net.UDPAddr{}
  98. ad.Do(parseSockaddr(p.Endpoint))
  99. case wgh.PeerAPersistentKeepaliveInterval:
  100. p.PersistentKeepaliveInterval = time.Duration(ad.Uint16()) * time.Second
  101. case wgh.PeerALastHandshakeTime:
  102. ad.Do(parseTimespec(&p.LastHandshakeTime))
  103. case wgh.PeerARxBytes:
  104. p.ReceiveBytes = int64(ad.Uint64())
  105. case wgh.PeerATxBytes:
  106. p.TransmitBytes = int64(ad.Uint64())
  107. case wgh.PeerAAllowedips:
  108. ad.Nested(parseAllowedIPs(&p.AllowedIPs))
  109. case wgh.PeerAProtocolVersion:
  110. p.ProtocolVersion = int(ad.Uint32())
  111. }
  112. }
  113. return p
  114. }
  115. // parseAllowedIPs parses a slice of net.IPNet from a netlink attribute payload.
  116. func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error {
  117. return func(ad *netlink.AttributeDecoder) error {
  118. // Initialize to the number of allowed IPs and begin iterating through
  119. // the netlink array to decode each one.
  120. *ipns = make([]net.IPNet, 0, ad.Len())
  121. for ad.Next() {
  122. // Allowed IP nested attributes.
  123. ad.Nested(func(nad *netlink.AttributeDecoder) error {
  124. var (
  125. ipn net.IPNet
  126. mask int
  127. family int
  128. )
  129. for nad.Next() {
  130. switch nad.Type() {
  131. case wgh.AllowedipAIpaddr:
  132. nad.Do(parseAddr(&ipn.IP))
  133. case wgh.AllowedipACidrMask:
  134. mask = int(nad.Uint8())
  135. case wgh.AllowedipAFamily:
  136. family = int(nad.Uint16())
  137. }
  138. }
  139. if err := nad.Err(); err != nil {
  140. return err
  141. }
  142. // The address family determines the correct number of bits in
  143. // the mask.
  144. switch family {
  145. case unix.AF_INET:
  146. ipn.Mask = net.CIDRMask(mask, 32)
  147. case unix.AF_INET6:
  148. ipn.Mask = net.CIDRMask(mask, 128)
  149. }
  150. *ipns = append(*ipns, ipn)
  151. return nil
  152. })
  153. }
  154. return nil
  155. }
  156. }
  157. // parseKey parses a wgtypes.Key from a byte slice.
  158. func parseKey(key *wgtypes.Key) func(b []byte) error {
  159. return func(b []byte) error {
  160. k, err := wgtypes.NewKey(b)
  161. if err != nil {
  162. return err
  163. }
  164. *key = k
  165. return nil
  166. }
  167. }
  // parseAddr returns a decoder that stores raw in_addr (4-byte) or in6_addr
// (16-byte) bytes into *ip.
//
// The bytes are copied, so *ip never aliases the input buffer. Any other
// length is rejected with an error and *ip is left unchanged.
func parseAddr(ip *net.IP) func(b []byte) error {
	return func(b []byte) error {
		n := len(b)
		if n != net.IPv4len && n != net.IPv6len {
			return fmt.Errorf("wglinux: unexpected IP address size: %d", n)
		}

		// net.IP shares the raw address byte layout; only a copy is needed.
		dup := make(net.IP, n)
		copy(dup, b)
		*ip = dup
		return nil
	}
}
  182. // parseSockaddr parses a *net.UDPAddr from raw sockaddr_in or sockaddr_in6 bytes.
  183. func parseSockaddr(endpoint *net.UDPAddr) func(b []byte) error {
  184. return func(b []byte) error {
  185. switch len(b) {
  186. case unix.SizeofSockaddrInet4:
  187. // IPv4 address parsing.
  188. sa := *(*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0]))
  189. *endpoint = net.UDPAddr{
  190. IP: net.IP(sa.Addr[:]).To4(),
  191. Port: int(sockaddrPort(int(sa.Port))),
  192. }
  193. return nil
  194. case unix.SizeofSockaddrInet6:
  195. // IPv6 address parsing.
  196. sa := *(*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0]))
  197. *endpoint = net.UDPAddr{
  198. IP: net.IP(sa.Addr[:]),
  199. Port: int(sockaddrPort(int(sa.Port))),
  200. }
  201. return nil
  202. default:
  203. return fmt.Errorf("wglinux: unexpected sockaddr size: %d", len(b))
  204. }
  205. }
  206. }
  // Native-layout mirrors of a timespec, used to reinterpret raw attribute bytes.
type (
	// timespec32 is a unix.Timespec with 32-bit integers.
	timespec32 struct {
		Sec  int32
		Nsec int32
	}

	// timespec64 is a unix.Timespec with 64-bit integers.
	timespec64 struct {
		Sec  int64
		Nsec int64
	}
)

// Byte sizes of the two timespec layouts; parseTimespec picks a layout by
// matching the payload length against these.
const (
	sizeofTimespec32 = int(unsafe.Sizeof(timespec32{}))
	sizeofTimespec64 = int(unsafe.Sizeof(timespec64{}))
)

// parseTimespec returns a decoder that converts a raw timespec payload into a
// time.Time stored in *t.
//
// Both the 64-bit and the 32-bit layouts are accepted, chosen by payload
// length; any other length is an error. When neither seconds nor nanoseconds
// is positive, *t is left untouched, so a zero-valued destination stays zero.
func parseTimespec(t *time.Time) func(b []byte) error {
	return func(b []byte) error {
		// It would appear that WireGuard can return a __kernel_timespec which
		// uses 64-bit integers, even on 32-bit platforms. Clarification of this
		// behavior is being sought in:
		// https://lists.zx2c4.com/pipermail/wireguard/2019-April/004088.html.
		//
		// Until then, be liberal and accept both the 32-bit and 64-bit layouts.
		var sec, nsec int64
		switch n := len(b); n {
		case sizeofTimespec64:
			ts := (*timespec64)(unsafe.Pointer(&b[0]))
			sec, nsec = ts.Sec, ts.Nsec
		case sizeofTimespec32:
			ts := (*timespec32)(unsafe.Pointer(&b[0]))
			sec, nsec = int64(ts.Sec), int64(ts.Nsec)
		default:
			return fmt.Errorf("wglinux: unexpected timespec size: %d bytes, expected 8 or 16 bytes", n)
		}

		if sec <= 0 && nsec <= 0 {
			return nil
		}
		*t = time.Unix(sec, nsec)
		return nil
	}
}
  251. // mergeDevices merges Peer information from d into target. mergeDevices is
  252. // used to deal with multiple incoming netlink messages for the same device.
  253. func mergeDevices(target, d *wgtypes.Device, knownPeers map[wgtypes.Key]int) {
  254. for i := range d.Peers {
  255. // Peer is already known, append to it's allowed IP networks
  256. if peerIndex, ok := knownPeers[d.Peers[i].PublicKey]; ok {
  257. target.Peers[peerIndex].AllowedIPs = append(target.Peers[peerIndex].AllowedIPs, d.Peers[i].AllowedIPs...)
  258. } else { // New peer, add it to the target peers.
  259. target.Peers = append(target.Peers, d.Peers[i])
  260. knownPeers[d.Peers[i].PublicKey] = len(target.Peers) - 1
  261. }
  262. }
  263. }