iptables.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Copyright 2019 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package iptables
  15. import (
  16. "fmt"
  17. "net"
  18. "strings"
  19. "sync"
  20. "time"
  21. "github.com/coreos/go-iptables/iptables"
  22. )
  23. // Protocol represents an IP protocol.
  24. type Protocol byte
  25. const (
  26. // ProtocolIPv4 represents the IPv4 protocol.
  27. ProtocolIPv4 Protocol = iota
  28. // ProtocolIPv6 represents the IPv6 protocol.
  29. ProtocolIPv6
  30. )
  31. // GetProtocol will return a protocol from the length of an IP address.
  32. func GetProtocol(length int) Protocol {
  33. if length == net.IPv6len {
  34. return ProtocolIPv6
  35. }
  36. return ProtocolIPv4
  37. }
  38. // Client represents any type that can administer iptables rules.
  39. type Client interface {
  40. AppendUnique(table string, chain string, rule ...string) error
  41. Delete(table string, chain string, rule ...string) error
  42. Exists(table string, chain string, rule ...string) (bool, error)
  43. ClearChain(table string, chain string) error
  44. DeleteChain(table string, chain string) error
  45. NewChain(table string, chain string) error
  46. }
  47. // Rule is an interface for interacting with iptables objects.
  48. type Rule interface {
  49. Add(Client) error
  50. Delete(Client) error
  51. Exists(Client) (bool, error)
  52. String() string
  53. Proto() Protocol
  54. }
  55. // rule represents an iptables rule.
  56. type rule struct {
  57. table string
  58. chain string
  59. spec []string
  60. proto Protocol
  61. }
  62. // NewRule creates a new iptables or ip6tables rule in the given table and chain
  63. // depending on the given protocol.
  64. func NewRule(proto Protocol, table, chain string, spec ...string) Rule {
  65. return &rule{table, chain, spec, proto}
  66. }
  67. // NewIPv4Rule creates a new iptables rule in the given table and chain.
  68. func NewIPv4Rule(table, chain string, spec ...string) Rule {
  69. return &rule{table, chain, spec, ProtocolIPv4}
  70. }
  71. // NewIPv6Rule creates a new ip6tables rule in the given table and chain.
  72. func NewIPv6Rule(table, chain string, spec ...string) Rule {
  73. return &rule{table, chain, spec, ProtocolIPv6}
  74. }
  75. func (r *rule) Add(client Client) error {
  76. if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil {
  77. return fmt.Errorf("failed to add iptables rule: %v", err)
  78. }
  79. return nil
  80. }
  81. func (r *rule) Delete(client Client) error {
  82. // Ignore the returned error as an error likely means
  83. // that the rule doesn't exist, which is fine.
  84. client.Delete(r.table, r.chain, r.spec...)
  85. return nil
  86. }
  87. func (r *rule) Exists(client Client) (bool, error) {
  88. return client.Exists(r.table, r.chain, r.spec...)
  89. }
  90. func (r *rule) String() string {
  91. if r == nil {
  92. return ""
  93. }
  94. return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
  95. }
  96. func (r *rule) Proto() Protocol {
  97. return r.proto
  98. }
  99. // chain represents an iptables chain.
  100. type chain struct {
  101. table string
  102. chain string
  103. proto Protocol
  104. }
  105. // NewIPv4Chain creates a new iptables chain in the given table.
  106. func NewIPv4Chain(table, name string) Rule {
  107. return &chain{table, name, ProtocolIPv4}
  108. }
  109. // NewIPv6Chain creates a new ip6tables chain in the given table.
  110. func NewIPv6Chain(table, name string) Rule {
  111. return &chain{table, name, ProtocolIPv6}
  112. }
  113. func (c *chain) Add(client Client) error {
  114. if err := client.ClearChain(c.table, c.chain); err != nil {
  115. return fmt.Errorf("failed to add iptables chain: %v", err)
  116. }
  117. return nil
  118. }
  119. func (c *chain) Delete(client Client) error {
  120. // The chain must be empty before it can be deleted.
  121. if err := client.ClearChain(c.table, c.chain); err != nil {
  122. return fmt.Errorf("failed to clear iptables chain: %v", err)
  123. }
  124. // Ignore the returned error as an error likely means
  125. // that the chain doesn't exist, which is fine.
  126. client.DeleteChain(c.table, c.chain)
  127. return nil
  128. }
  129. func (c *chain) Exists(client Client) (bool, error) {
  130. // The code for "chain already exists".
  131. existsErr := 1
  132. err := client.NewChain(c.table, c.chain)
  133. se, ok := err.(statusExiter)
  134. switch {
  135. case err == nil:
  136. // If there was no error adding a new chain, then it did not exist.
  137. // Delete it and return false.
  138. client.DeleteChain(c.table, c.chain)
  139. return false, nil
  140. case ok && se.ExitStatus() == existsErr:
  141. return true, nil
  142. default:
  143. return false, err
  144. }
  145. }
  146. func (c *chain) String() string {
  147. if c == nil {
  148. return ""
  149. }
  150. return fmt.Sprintf("%s_%s", c.table, c.chain)
  151. }
  152. func (c *chain) Proto() Protocol {
  153. return c.proto
  154. }
  155. // Controller is able to reconcile a given set of iptables rules.
  156. type Controller struct {
  157. v4 Client
  158. v6 Client
  159. errors chan error
  160. sync.Mutex
  161. rules []Rule
  162. subscribed bool
  163. }
  164. // New generates a new iptables rules controller.
  165. // It expects an IP address length to determine
  166. // whether to operate in IPv4 or IPv6 mode.
  167. func New() (*Controller, error) {
  168. v4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
  169. if err != nil {
  170. return nil, fmt.Errorf("failed to create iptables IPv4 client: %v", err)
  171. }
  172. v6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
  173. if err != nil {
  174. return nil, fmt.Errorf("failed to create iptables IPv6 client: %v", err)
  175. }
  176. return &Controller{
  177. v4: v4,
  178. v6: v6,
  179. errors: make(chan error),
  180. }, nil
  181. }
  182. // Run watches for changes to iptables rules and reconciles
  183. // the rules against the desired state.
  184. func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) {
  185. c.Lock()
  186. if c.subscribed {
  187. c.Unlock()
  188. return c.errors, nil
  189. }
  190. // Ensure a given instance only subscribes once.
  191. c.subscribed = true
  192. c.Unlock()
  193. go func() {
  194. defer close(c.errors)
  195. for {
  196. select {
  197. case <-time.After(5 * time.Second):
  198. case <-stop:
  199. return
  200. }
  201. if err := c.reconcile(); err != nil {
  202. nonBlockingSend(c.errors, fmt.Errorf("failed to reconcile rules: %v", err))
  203. }
  204. }
  205. }()
  206. return c.errors, nil
  207. }
  208. // reconcile makes sure that every rule is still in the backend.
  209. // It does not ensure that the order in the backend is correct.
  210. // If any rule is missing, that rule and all following rules are
  211. // re-added.
  212. func (c *Controller) reconcile() error {
  213. c.Lock()
  214. defer c.Unlock()
  215. for i, r := range c.rules {
  216. ok, err := r.Exists(c.client(r.Proto()))
  217. if err != nil {
  218. return fmt.Errorf("failed to check if rule exists: %v", err)
  219. }
  220. if !ok {
  221. if err := c.resetFromIndex(i, c.rules); err != nil {
  222. return fmt.Errorf("failed to add rule: %v", err)
  223. }
  224. break
  225. }
  226. }
  227. return nil
  228. }
  229. // resetFromIndex re-adds all rules starting from the given index.
  230. func (c *Controller) resetFromIndex(i int, rules []Rule) error {
  231. if i >= len(rules) {
  232. return nil
  233. }
  234. for j := i; j < len(rules); j++ {
  235. if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil {
  236. return fmt.Errorf("failed to delete rule: %v", err)
  237. }
  238. if err := rules[j].Add(c.client(rules[j].Proto())); err != nil {
  239. return fmt.Errorf("failed to add rule: %v", err)
  240. }
  241. }
  242. return nil
  243. }
  244. // deleteFromIndex deletes all rules starting from the given index.
  245. func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
  246. if i >= len(*rules) {
  247. return nil
  248. }
  249. for j := i; j < len(*rules); j++ {
  250. if err := (*rules)[j].Delete(c.client((*rules)[j].Proto())); err != nil {
  251. *rules = append((*rules)[:i], (*rules)[j:]...)
  252. return fmt.Errorf("failed to delete rule: %v", err)
  253. }
  254. (*rules)[j] = nil
  255. }
  256. *rules = (*rules)[:i]
  257. return nil
  258. }
  259. // Set idempotently overwrites any iptables rules previously defined
  260. // for the controller with the given set of rules.
  261. func (c *Controller) Set(rules []Rule) error {
  262. c.Lock()
  263. defer c.Unlock()
  264. var i int
  265. for ; i < len(rules); i++ {
  266. if i < len(c.rules) {
  267. if rules[i].String() != c.rules[i].String() {
  268. if err := c.deleteFromIndex(i, &c.rules); err != nil {
  269. return err
  270. }
  271. }
  272. }
  273. if i >= len(c.rules) {
  274. if err := rules[i].Add(c.client(rules[i].Proto())); err != nil {
  275. return fmt.Errorf("failed to add rule: %v", err)
  276. }
  277. c.rules = append(c.rules, rules[i])
  278. }
  279. }
  280. return c.deleteFromIndex(i, &c.rules)
  281. }
  282. // CleanUp will clean up any rules created by the controller.
  283. func (c *Controller) CleanUp() error {
  284. c.Lock()
  285. defer c.Unlock()
  286. return c.deleteFromIndex(0, &c.rules)
  287. }
  288. func (c *Controller) client(p Protocol) Client {
  289. switch p {
  290. case ProtocolIPv4:
  291. return c.v4
  292. case ProtocolIPv6:
  293. return c.v6
  294. default:
  295. panic("unknown protocol")
  296. }
  297. }
  298. func nonBlockingSend(errors chan<- error, err error) {
  299. select {
  300. case errors <- err:
  301. default:
  302. }
  303. }