ipset.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. // Copyright 2019 the Kilo authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ipset
  15. import (
  16. "bytes"
  17. "fmt"
  18. "net"
  19. "os/exec"
  20. "sync"
  21. "time"
  22. )
  23. // Set represents an ipset.
  24. // Set can safely be used concurrently.
  25. type Set struct {
  26. errors chan error
  27. hosts map[string]struct{}
  28. mu sync.Mutex
  29. name string
  30. subscribed bool
  31. // Make these functions fields to allow
  32. // for testing.
  33. add func(string) error
  34. del func(string) error
  35. }
  36. func setExists(name string) (bool, error) {
  37. cmd := exec.Command("ipset", "list", "-n")
  38. var stderr, stdout bytes.Buffer
  39. cmd.Stderr = &stderr
  40. cmd.Stdout = &stdout
  41. if err := cmd.Run(); err != nil {
  42. return false, fmt.Errorf("failed to check for set %s: %s", name, stderr.String())
  43. }
  44. return bytes.Contains(stdout.Bytes(), []byte(name)), nil
  45. }
  46. func hostInSet(set, name string) (bool, error) {
  47. cmd := exec.Command("ipset", "list", set)
  48. var stderr, stdout bytes.Buffer
  49. cmd.Stderr = &stderr
  50. cmd.Stdout = &stdout
  51. if err := cmd.Run(); err != nil {
  52. return false, fmt.Errorf("failed to check for host %s: %s", name, stderr.String())
  53. }
  54. return bytes.Contains(stdout.Bytes(), []byte(name)), nil
  55. }
  56. // New generates a new ipset.
  57. func New(name string) *Set {
  58. return &Set{
  59. errors: make(chan error),
  60. hosts: make(map[string]struct{}),
  61. name: name,
  62. add: func(ip string) error {
  63. ok, err := hostInSet(name, ip)
  64. if err != nil {
  65. return err
  66. }
  67. if !ok {
  68. cmd := exec.Command("ipset", "add", name, ip)
  69. var stderr bytes.Buffer
  70. cmd.Stderr = &stderr
  71. if err := cmd.Run(); err != nil {
  72. return fmt.Errorf("failed to add host %s to set %s: %s", ip, name, stderr.String())
  73. }
  74. }
  75. return nil
  76. },
  77. del: func(ip string) error {
  78. ok, err := hostInSet(name, ip)
  79. if err != nil {
  80. return err
  81. }
  82. if ok {
  83. cmd := exec.Command("ipset", "del", name, ip)
  84. var stderr bytes.Buffer
  85. cmd.Stderr = &stderr
  86. if err := cmd.Run(); err != nil {
  87. return fmt.Errorf("failed to remove host %s from set %s: %s", ip, name, stderr.String())
  88. }
  89. }
  90. return nil
  91. },
  92. }
  93. }
  94. // Run watches for changes to the ipset and reconciles
  95. // the ipset against the desired state.
  96. func (s *Set) Run(stop <-chan struct{}) (<-chan error, error) {
  97. s.mu.Lock()
  98. if s.subscribed {
  99. s.mu.Unlock()
  100. return s.errors, nil
  101. }
  102. // Ensure a given instance only subscribes once.
  103. s.subscribed = true
  104. s.mu.Unlock()
  105. go func() {
  106. defer close(s.errors)
  107. for {
  108. select {
  109. case <-time.After(2 * time.Second):
  110. case <-stop:
  111. return
  112. }
  113. ok, err := setExists(s.name)
  114. if err != nil {
  115. nonBlockingSend(s.errors, err)
  116. }
  117. // The set does not exist so wait and try again later.
  118. if !ok {
  119. continue
  120. }
  121. s.mu.Lock()
  122. for h := range s.hosts {
  123. if err := s.add(h); err != nil {
  124. nonBlockingSend(s.errors, err)
  125. }
  126. }
  127. s.mu.Unlock()
  128. }
  129. }()
  130. return s.errors, nil
  131. }
  132. // CleanUp will clean up any hosts added to the set.
  133. func (s *Set) CleanUp() error {
  134. s.mu.Lock()
  135. defer s.mu.Unlock()
  136. for h := range s.hosts {
  137. if err := s.del(h); err != nil {
  138. return err
  139. }
  140. delete(s.hosts, h)
  141. }
  142. return nil
  143. }
  144. // Set idempotently overwrites any hosts previously defined
  145. // for the ipset with the given hosts.
  146. func (s *Set) Set(hosts []net.IP) error {
  147. h := make(map[string]struct{})
  148. for _, host := range hosts {
  149. if host == nil {
  150. continue
  151. }
  152. h[host.String()] = struct{}{}
  153. }
  154. exists, err := setExists(s.name)
  155. if err != nil {
  156. return err
  157. }
  158. s.mu.Lock()
  159. defer s.mu.Unlock()
  160. for k := range s.hosts {
  161. if _, ok := h[k]; !ok {
  162. if exists {
  163. if err := s.del(k); err != nil {
  164. return err
  165. }
  166. }
  167. delete(s.hosts, k)
  168. }
  169. }
  170. for k := range h {
  171. if _, ok := s.hosts[k]; !ok {
  172. if exists {
  173. if err := s.add(k); err != nil {
  174. return err
  175. }
  176. }
  177. s.hosts[k] = struct{}{}
  178. }
  179. }
  180. return nil
  181. }
  182. func nonBlockingSend(errors chan<- error, err error) {
  183. select {
  184. case errors <- err:
  185. default:
  186. }
  187. }