conn_linux.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. //+build linux
  2. package netlink
import (
	"fmt"
	"os"
	"runtime"
	"syscall"
	"time"
	"unsafe"

	"github.com/mdlayher/socket"
	"golang.org/x/net/bpf"
	"golang.org/x/sys/unix"
)
  13. var _ Socket = &conn{}
  14. // A conn is the Linux implementation of a netlink sockets connection.
  15. type conn struct {
  16. s *socket.Conn
  17. }
  18. // dial is the entry point for Dial. dial opens a netlink socket using
  19. // system calls, and returns its PID.
  20. func dial(family int, config *Config) (*conn, uint32, error) {
  21. if config == nil {
  22. config = &Config{}
  23. }
  24. // The caller has indicated it wants the netlink socket to be created
  25. // inside another network namespace.
  26. if config.NetNS != 0 {
  27. runtime.LockOSThread()
  28. defer runtime.UnlockOSThread()
  29. // Retrieve and store the calling OS thread's network namespace so
  30. // the thread can be reassigned to it after creating a socket in another
  31. // network namespace.
  32. threadNS, err := threadNetNS()
  33. if err != nil {
  34. return nil, 0, err
  35. }
  36. // Always close the netns handle created above.
  37. defer threadNS.Close()
  38. // Assign the current OS thread the goroutine is locked to to the given
  39. // network namespace.
  40. if err := threadNS.Set(config.NetNS); err != nil {
  41. return nil, 0, err
  42. }
  43. // Thread's namespace has been successfully set. Return the thread
  44. // back to its original namespace after attempting to create the
  45. // netlink socket.
  46. defer threadNS.Restore()
  47. }
  48. // Prepare the netlink socket.
  49. s, err := socket.Socket(unix.AF_NETLINK, unix.SOCK_RAW, family, "netlink")
  50. if err != nil {
  51. return nil, 0, err
  52. }
  53. return newConn(s, config)
  54. }
  55. // newConn binds a connection to netlink using the input *socket.Conn.
  56. func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
  57. if config == nil {
  58. config = &Config{}
  59. }
  60. addr := &unix.SockaddrNetlink{
  61. Family: unix.AF_NETLINK,
  62. Groups: config.Groups,
  63. }
  64. // Socket must be closed in the event of any system call errors, to avoid
  65. // leaking file descriptors.
  66. if err := s.Bind(addr); err != nil {
  67. _ = s.Close()
  68. return nil, 0, err
  69. }
  70. sa, err := s.Getsockname()
  71. if err != nil {
  72. _ = s.Close()
  73. return nil, 0, err
  74. }
  75. return &conn{
  76. s: s,
  77. }, sa.(*unix.SockaddrNetlink).Pid, nil
  78. }
  79. // SendMessages serializes multiple Messages and sends them to netlink.
  80. func (c *conn) SendMessages(messages []Message) error {
  81. var buf []byte
  82. for _, m := range messages {
  83. b, err := m.MarshalBinary()
  84. if err != nil {
  85. return err
  86. }
  87. buf = append(buf, b...)
  88. }
  89. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  90. return c.s.Sendmsg(buf, nil, sa, 0)
  91. }
  92. // Send sends a single Message to netlink.
  93. func (c *conn) Send(m Message) error {
  94. b, err := m.MarshalBinary()
  95. if err != nil {
  96. return err
  97. }
  98. sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK}
  99. return c.s.Sendmsg(b, nil, sa, 0)
  100. }
  101. // Receive receives one or more Messages from netlink.
  102. func (c *conn) Receive() ([]Message, error) {
  103. b := make([]byte, os.Getpagesize())
  104. for {
  105. // Peek at the buffer to see how many bytes are available.
  106. //
  107. // TODO(mdlayher): deal with OOB message data if available, such as
  108. // when PacketInfo ConnOption is true.
  109. n, _, _, _, err := c.s.Recvmsg(b, nil, unix.MSG_PEEK)
  110. if err != nil {
  111. return nil, err
  112. }
  113. // Break when we can read all messages
  114. if n < len(b) {
  115. break
  116. }
  117. // Double in size if not enough bytes
  118. b = make([]byte, len(b)*2)
  119. }
  120. // Read out all available messages
  121. n, _, _, _, err := c.s.Recvmsg(b, nil, 0)
  122. if err != nil {
  123. return nil, err
  124. }
  125. raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)])
  126. if err != nil {
  127. return nil, err
  128. }
  129. msgs := make([]Message, 0, len(raw))
  130. for _, r := range raw {
  131. m := Message{
  132. Header: sysToHeader(r.Header),
  133. Data: r.Data,
  134. }
  135. msgs = append(msgs, m)
  136. }
  137. return msgs, nil
  138. }
  139. // Close closes the connection.
  140. func (c *conn) Close() error { return c.s.Close() }
  141. // JoinGroup joins a multicast group by ID.
  142. func (c *conn) JoinGroup(group uint32) error {
  143. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group))
  144. }
  145. // LeaveGroup leaves a multicast group by ID.
  146. func (c *conn) LeaveGroup(group uint32) error {
  147. return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group))
  148. }
  149. // SetBPF attaches an assembled BPF program to a conn.
  150. func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) }
  151. // RemoveBPF removes a BPF filter from a conn.
  152. func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() }
  153. // SetOption enables or disables a netlink socket option for the Conn.
  154. func (c *conn) SetOption(option ConnOption, enable bool) error {
  155. o, ok := linuxOption(option)
  156. if !ok {
  157. // Return the typical Linux error for an unknown ConnOption.
  158. return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT)
  159. }
  160. var v int
  161. if enable {
  162. v = 1
  163. }
  164. return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v)
  165. }
  166. func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) }
  167. func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) }
  168. func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) }
  169. // SetReadBuffer sets the size of the operating system's receive buffer
  170. // associated with the Conn.
  171. func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) }
  172. // SetReadBuffer sets the size of the operating system's transmit buffer
  173. // associated with the Conn.
  174. func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) }
  175. // SyscallConn returns a raw network connection.
  176. func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() }
  177. // linuxOption converts a ConnOption to its Linux value.
  178. func linuxOption(o ConnOption) (int, bool) {
  179. switch o {
  180. case PacketInfo:
  181. return unix.NETLINK_PKTINFO, true
  182. case BroadcastError:
  183. return unix.NETLINK_BROADCAST_ERROR, true
  184. case NoENOBUFS:
  185. return unix.NETLINK_NO_ENOBUFS, true
  186. case ListenAllNSID:
  187. return unix.NETLINK_LISTEN_ALL_NSID, true
  188. case CapAcknowledge:
  189. return unix.NETLINK_CAP_ACK, true
  190. case ExtendedAcknowledge:
  191. return unix.NETLINK_EXT_ACK, true
  192. case GetStrictCheck:
  193. return unix.NETLINK_GET_STRICT_CHK, true
  194. default:
  195. return 0, false
  196. }
  197. }
  198. // sysToHeader converts a syscall.NlMsghdr to a Header.
  199. func sysToHeader(r syscall.NlMsghdr) Header {
  200. // NB: the memory layout of Header and syscall.NlMsgHdr must be
  201. // exactly the same for this unsafe cast to work
  202. return *(*Header)(unsafe.Pointer(&r))
  203. }
  204. // newError converts an error number from netlink into the appropriate
  205. // system call error for Linux.
  206. func newError(errno int) error {
  207. return syscall.Errno(errno)
  208. }