rule_linux.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package netlink
import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net"

	"github.com/vishvananda/netlink/nl"
	"golang.org/x/sys/unix"
)
  9. const FibRuleInvert = 0x2
  10. // RuleAdd adds a rule to the system.
  11. // Equivalent to: ip rule add
  12. func RuleAdd(rule *Rule) error {
  13. return pkgHandle.RuleAdd(rule)
  14. }
  15. // RuleAdd adds a rule to the system.
  16. // Equivalent to: ip rule add
  17. func (h *Handle) RuleAdd(rule *Rule) error {
  18. req := h.newNetlinkRequest(unix.RTM_NEWRULE, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
  19. return ruleHandle(rule, req)
  20. }
  21. // RuleDel deletes a rule from the system.
  22. // Equivalent to: ip rule del
  23. func RuleDel(rule *Rule) error {
  24. return pkgHandle.RuleDel(rule)
  25. }
  26. // RuleDel deletes a rule from the system.
  27. // Equivalent to: ip rule del
  28. func (h *Handle) RuleDel(rule *Rule) error {
  29. req := h.newNetlinkRequest(unix.RTM_DELRULE, unix.NLM_F_ACK)
  30. return ruleHandle(rule, req)
  31. }
  32. func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
  33. msg := nl.NewRtMsg()
  34. msg.Family = unix.AF_INET
  35. msg.Protocol = unix.RTPROT_BOOT
  36. msg.Scope = unix.RT_SCOPE_UNIVERSE
  37. msg.Table = unix.RT_TABLE_UNSPEC
  38. msg.Type = unix.RTN_UNSPEC
  39. if req.NlMsghdr.Flags&unix.NLM_F_CREATE > 0 {
  40. msg.Type = unix.RTN_UNICAST
  41. }
  42. if rule.Invert {
  43. msg.Flags |= FibRuleInvert
  44. }
  45. if rule.Family != 0 {
  46. msg.Family = uint8(rule.Family)
  47. }
  48. if rule.Table >= 0 && rule.Table < 256 {
  49. msg.Table = uint8(rule.Table)
  50. }
  51. if rule.Tos != 0 {
  52. msg.Tos = uint8(rule.Tos)
  53. }
  54. var dstFamily uint8
  55. var rtAttrs []*nl.RtAttr
  56. if rule.Dst != nil && rule.Dst.IP != nil {
  57. dstLen, _ := rule.Dst.Mask.Size()
  58. msg.Dst_len = uint8(dstLen)
  59. msg.Family = uint8(nl.GetIPFamily(rule.Dst.IP))
  60. dstFamily = msg.Family
  61. var dstData []byte
  62. if msg.Family == unix.AF_INET {
  63. dstData = rule.Dst.IP.To4()
  64. } else {
  65. dstData = rule.Dst.IP.To16()
  66. }
  67. rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
  68. }
  69. if rule.Src != nil && rule.Src.IP != nil {
  70. msg.Family = uint8(nl.GetIPFamily(rule.Src.IP))
  71. if dstFamily != 0 && dstFamily != msg.Family {
  72. return fmt.Errorf("source and destination ip are not the same IP family")
  73. }
  74. srcLen, _ := rule.Src.Mask.Size()
  75. msg.Src_len = uint8(srcLen)
  76. var srcData []byte
  77. if msg.Family == unix.AF_INET {
  78. srcData = rule.Src.IP.To4()
  79. } else {
  80. srcData = rule.Src.IP.To16()
  81. }
  82. rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_SRC, srcData))
  83. }
  84. req.AddData(msg)
  85. for i := range rtAttrs {
  86. req.AddData(rtAttrs[i])
  87. }
  88. native := nl.NativeEndian()
  89. if rule.Priority >= 0 {
  90. b := make([]byte, 4)
  91. native.PutUint32(b, uint32(rule.Priority))
  92. req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
  93. }
  94. if rule.Mark >= 0 {
  95. b := make([]byte, 4)
  96. native.PutUint32(b, uint32(rule.Mark))
  97. req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
  98. }
  99. if rule.Mask >= 0 {
  100. b := make([]byte, 4)
  101. native.PutUint32(b, uint32(rule.Mask))
  102. req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
  103. }
  104. if rule.Flow >= 0 {
  105. b := make([]byte, 4)
  106. native.PutUint32(b, uint32(rule.Flow))
  107. req.AddData(nl.NewRtAttr(nl.FRA_FLOW, b))
  108. }
  109. if rule.TunID > 0 {
  110. b := make([]byte, 4)
  111. native.PutUint32(b, uint32(rule.TunID))
  112. req.AddData(nl.NewRtAttr(nl.FRA_TUN_ID, b))
  113. }
  114. if rule.Table >= 256 {
  115. b := make([]byte, 4)
  116. native.PutUint32(b, uint32(rule.Table))
  117. req.AddData(nl.NewRtAttr(nl.FRA_TABLE, b))
  118. }
  119. if msg.Table > 0 {
  120. if rule.SuppressPrefixlen >= 0 {
  121. b := make([]byte, 4)
  122. native.PutUint32(b, uint32(rule.SuppressPrefixlen))
  123. req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_PREFIXLEN, b))
  124. }
  125. if rule.SuppressIfgroup >= 0 {
  126. b := make([]byte, 4)
  127. native.PutUint32(b, uint32(rule.SuppressIfgroup))
  128. req.AddData(nl.NewRtAttr(nl.FRA_SUPPRESS_IFGROUP, b))
  129. }
  130. }
  131. if rule.IifName != "" {
  132. req.AddData(nl.NewRtAttr(nl.FRA_IIFNAME, []byte(rule.IifName+"\x00")))
  133. }
  134. if rule.OifName != "" {
  135. req.AddData(nl.NewRtAttr(nl.FRA_OIFNAME, []byte(rule.OifName+"\x00")))
  136. }
  137. if rule.Goto >= 0 {
  138. msg.Type = nl.FR_ACT_GOTO
  139. b := make([]byte, 4)
  140. native.PutUint32(b, uint32(rule.Goto))
  141. req.AddData(nl.NewRtAttr(nl.FRA_GOTO, b))
  142. }
  143. if rule.Dport != nil {
  144. b := rule.Dport.toRtAttrData()
  145. req.AddData(nl.NewRtAttr(nl.FRA_DPORT_RANGE, b))
  146. }
  147. if rule.Sport != nil {
  148. b := rule.Sport.toRtAttrData()
  149. req.AddData(nl.NewRtAttr(nl.FRA_SPORT_RANGE, b))
  150. }
  151. _, err := req.Execute(unix.NETLINK_ROUTE, 0)
  152. return err
  153. }
  154. // RuleList lists rules in the system.
  155. // Equivalent to: ip rule list
  156. func RuleList(family int) ([]Rule, error) {
  157. return pkgHandle.RuleList(family)
  158. }
  159. // RuleList lists rules in the system.
  160. // Equivalent to: ip rule list
  161. func (h *Handle) RuleList(family int) ([]Rule, error) {
  162. return h.RuleListFiltered(family, nil, 0)
  163. }
  164. // RuleListFiltered gets a list of rules in the system filtered by the
  165. // specified rule template `filter`.
  166. // Equivalent to: ip rule list
  167. func RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
  168. return pkgHandle.RuleListFiltered(family, filter, filterMask)
  169. }
  170. // RuleListFiltered lists rules in the system.
  171. // Equivalent to: ip rule list
  172. func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ([]Rule, error) {
  173. req := h.newNetlinkRequest(unix.RTM_GETRULE, unix.NLM_F_DUMP|unix.NLM_F_REQUEST)
  174. msg := nl.NewIfInfomsg(family)
  175. req.AddData(msg)
  176. msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWRULE)
  177. if err != nil {
  178. return nil, err
  179. }
  180. native := nl.NativeEndian()
  181. var res = make([]Rule, 0)
  182. for i := range msgs {
  183. msg := nl.DeserializeRtMsg(msgs[i])
  184. attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
  185. if err != nil {
  186. return nil, err
  187. }
  188. rule := NewRule()
  189. rule.Invert = msg.Flags&FibRuleInvert > 0
  190. rule.Tos = uint(msg.Tos)
  191. for j := range attrs {
  192. switch attrs[j].Attr.Type {
  193. case unix.RTA_TABLE:
  194. rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
  195. case nl.FRA_SRC:
  196. rule.Src = &net.IPNet{
  197. IP: attrs[j].Value,
  198. Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
  199. }
  200. case nl.FRA_DST:
  201. rule.Dst = &net.IPNet{
  202. IP: attrs[j].Value,
  203. Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
  204. }
  205. case nl.FRA_FWMARK:
  206. rule.Mark = int(native.Uint32(attrs[j].Value[0:4]))
  207. case nl.FRA_FWMASK:
  208. rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
  209. case nl.FRA_TUN_ID:
  210. rule.TunID = uint(native.Uint64(attrs[j].Value[0:4]))
  211. case nl.FRA_IIFNAME:
  212. rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
  213. case nl.FRA_OIFNAME:
  214. rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
  215. case nl.FRA_SUPPRESS_PREFIXLEN:
  216. i := native.Uint32(attrs[j].Value[0:4])
  217. if i != 0xffffffff {
  218. rule.SuppressPrefixlen = int(i)
  219. }
  220. case nl.FRA_SUPPRESS_IFGROUP:
  221. i := native.Uint32(attrs[j].Value[0:4])
  222. if i != 0xffffffff {
  223. rule.SuppressIfgroup = int(i)
  224. }
  225. case nl.FRA_FLOW:
  226. rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
  227. case nl.FRA_GOTO:
  228. rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
  229. case nl.FRA_PRIORITY:
  230. rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
  231. case nl.FRA_DPORT_RANGE:
  232. rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
  233. case nl.FRA_SPORT_RANGE:
  234. rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
  235. }
  236. }
  237. if filter != nil {
  238. switch {
  239. case filterMask&RT_FILTER_SRC != 0 &&
  240. (rule.Src == nil || rule.Src.String() != filter.Src.String()):
  241. continue
  242. case filterMask&RT_FILTER_DST != 0 &&
  243. (rule.Dst == nil || rule.Dst.String() != filter.Dst.String()):
  244. continue
  245. case filterMask&RT_FILTER_TABLE != 0 &&
  246. filter.Table != unix.RT_TABLE_UNSPEC && rule.Table != filter.Table:
  247. continue
  248. case filterMask&RT_FILTER_TOS != 0 && rule.Tos != filter.Tos:
  249. continue
  250. case filterMask&RT_FILTER_PRIORITY != 0 && rule.Priority != filter.Priority:
  251. continue
  252. case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
  253. continue
  254. case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask:
  255. continue
  256. }
  257. }
  258. res = append(res, *rule)
  259. }
  260. return res, nil
  261. }
  262. func (pr *RulePortRange) toRtAttrData() []byte {
  263. b := [][]byte{make([]byte, 2), make([]byte, 2)}
  264. native.PutUint16(b[0], pr.Start)
  265. native.PutUint16(b[1], pr.End)
  266. return bytes.Join(b, []byte{})
  267. }