family_linux.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. //+build linux
  2. package genetlink
  3. import (
  4. "errors"
  5. "fmt"
  6. "math"
  7. "github.com/mdlayher/netlink"
  8. "github.com/mdlayher/netlink/nlenc"
  9. "golang.org/x/sys/unix"
  10. )
  var (
  	// errInvalidFamilyVersion is reported while decoding a family when its
  	// CTRL_ATTR_VERSION attribute holds a value that does not fit in a uint8
  	// (that is, a value larger than math.MaxUint8).
  	errInvalidFamilyVersion = errors.New("invalid family version attribute")
  )
  14. // getFamily retrieves a generic netlink family with the specified name.
  15. func (c *Conn) getFamily(name string) (Family, error) {
  16. b, err := netlink.MarshalAttributes([]netlink.Attribute{{
  17. Type: unix.CTRL_ATTR_FAMILY_NAME,
  18. Data: nlenc.Bytes(name),
  19. }})
  20. if err != nil {
  21. return Family{}, err
  22. }
  23. req := Message{
  24. Header: Header{
  25. Command: unix.CTRL_CMD_GETFAMILY,
  26. // TODO(mdlayher): grab nlctrl version?
  27. Version: 1,
  28. },
  29. Data: b,
  30. }
  31. msgs, err := c.Execute(req, unix.GENL_ID_CTRL, netlink.Request)
  32. if err != nil {
  33. return Family{}, err
  34. }
  35. // TODO(mdlayher): consider interpreting generic netlink header values
  36. families, err := buildFamilies(msgs)
  37. if err != nil {
  38. return Family{}, err
  39. }
  40. if len(families) != 1 {
  41. // If this were to ever happen, netlink must be in a state where
  42. // its answers cannot be trusted
  43. panic(fmt.Sprintf("netlink returned multiple families for name: %q", name))
  44. }
  45. return families[0], nil
  46. }
  47. // listFamilies retrieves all registered generic netlink families.
  48. func (c *Conn) listFamilies() ([]Family, error) {
  49. req := Message{
  50. Header: Header{
  51. Command: unix.CTRL_CMD_GETFAMILY,
  52. // TODO(mdlayher): grab nlctrl version?
  53. Version: 1,
  54. },
  55. }
  56. flags := netlink.Request | netlink.Dump
  57. msgs, err := c.Execute(req, unix.GENL_ID_CTRL, flags)
  58. if err != nil {
  59. return nil, err
  60. }
  61. return buildFamilies(msgs)
  62. }
  63. // buildFamilies builds a slice of Families by parsing attributes from the
  64. // input Messages.
  65. func buildFamilies(msgs []Message) ([]Family, error) {
  66. families := make([]Family, 0, len(msgs))
  67. for _, m := range msgs {
  68. var f Family
  69. if err := (&f).parseAttributes(m.Data); err != nil {
  70. return nil, err
  71. }
  72. families = append(families, f)
  73. }
  74. return families, nil
  75. }
  76. // parseAttributes decodes netlink attributes into a Family's fields.
  77. func (f *Family) parseAttributes(b []byte) error {
  78. ad, err := netlink.NewAttributeDecoder(b)
  79. if err != nil {
  80. return err
  81. }
  82. for ad.Next() {
  83. switch ad.Type() {
  84. case unix.CTRL_ATTR_FAMILY_ID:
  85. f.ID = ad.Uint16()
  86. case unix.CTRL_ATTR_FAMILY_NAME:
  87. f.Name = ad.String()
  88. case unix.CTRL_ATTR_VERSION:
  89. v := ad.Uint32()
  90. if v > math.MaxUint8 {
  91. return errInvalidFamilyVersion
  92. }
  93. f.Version = uint8(v)
  94. case unix.CTRL_ATTR_MCAST_GROUPS:
  95. ad.Nested(func(nad *netlink.AttributeDecoder) error {
  96. f.Groups = parseMulticastGroups(nad)
  97. return nil
  98. })
  99. }
  100. }
  101. return ad.Err()
  102. }
  103. // parseMulticastGroups parses an array of multicast group nested attributes
  104. // into a slice of MulticastGroups.
  105. func parseMulticastGroups(ad *netlink.AttributeDecoder) []MulticastGroup {
  106. groups := make([]MulticastGroup, 0, ad.Len())
  107. for ad.Next() {
  108. ad.Nested(func(nad *netlink.AttributeDecoder) error {
  109. var g MulticastGroup
  110. for nad.Next() {
  111. switch nad.Type() {
  112. case unix.CTRL_ATTR_MCAST_GRP_NAME:
  113. g.Name = nad.String()
  114. case unix.CTRL_ATTR_MCAST_GRP_ID:
  115. g.ID = nad.Uint32()
  116. }
  117. }
  118. groups = append(groups, g)
  119. return nil
  120. })
  121. }
  122. return groups
  123. }