Bladeren bron

Merge pull request #116 from squat/reduce_iptables_calls

pkg/iptables: reduce calls to iptables
Lucas Servén Marín 5 jaren geleden
bovenliggende
commit
c060bf24e2
7 gewijzigde bestanden met toevoegingen van 353 en 32 verwijderingen
  1. 41 1
      pkg/iptables/fake.go
  2. 67 19
      pkg/iptables/iptables.go
  3. 8 6
      pkg/iptables/iptables_test.go
  4. 106 0
      pkg/iptables/rulecache.go
  5. 125 0
      pkg/iptables/rulecache_test.go
  6. 1 1
      pkg/mesh/mesh.go
  7. 5 5
      pkg/mesh/routes.go

+ 41 - 1
pkg/iptables/fake.go

@@ -16,6 +16,8 @@ package iptables
 
 import (
 	"fmt"
+	"strings"
+	"sync/atomic"
 
 	"github.com/coreos/go-iptables/iptables"
 )
@@ -38,12 +40,14 @@ func (s statusError) ExitStatus() int {
 }
 
 type fakeClient struct {
+	calls   uint64
 	storage []Rule
 }
 
 var _ Client = &fakeClient{}
 
 func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
+	atomic.AddUint64(&f.calls, 1)
 	exists, err := f.Exists(table, chain, spec...)
 	if err != nil {
 		return err
@@ -56,6 +60,7 @@ func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
 }
 
 func (f *fakeClient) Delete(table, chain string, spec ...string) error {
+	atomic.AddUint64(&f.calls, 1)
 	r := &rule{table: table, chain: chain, spec: spec}
 	for i := range f.storage {
 		if f.storage[i].String() == r.String() {
@@ -69,6 +74,7 @@ func (f *fakeClient) Delete(table, chain string, spec ...string) error {
 }
 
 func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
+	atomic.AddUint64(&f.calls, 1)
 	r := &rule{table: table, chain: chain, spec: spec}
 	for i := range f.storage {
 		if f.storage[i].String() == r.String() {
@@ -78,7 +84,22 @@ func (f *fakeClient) Exists(table, chain string, spec ...string) (bool, error) {
 	return false, nil
 }
 
+func (f *fakeClient) List(table, chain string) ([]string, error) {
+	atomic.AddUint64(&f.calls, 1)
+	var rs []string
+	for i := range f.storage {
+		switch r := f.storage[i].(type) {
+		case *rule:
+			if r.table == table && r.chain == chain {
+				rs = append(rs, strings.TrimSpace(strings.TrimPrefix(r.String(), table)))
+			}
+		}
+	}
+	return rs, nil
+}
+
 func (f *fakeClient) ClearChain(table, name string) error {
+	atomic.AddUint64(&f.calls, 1)
 	for i := range f.storage {
 		r, ok := f.storage[i].(*rule)
 		if !ok {
@@ -90,10 +111,14 @@ func (f *fakeClient) ClearChain(table, name string) error {
 			}
 		}
 	}
-	return f.DeleteChain(table, name)
+	if err := f.DeleteChain(table, name); err != nil {
+		return err
+	}
+	return f.NewChain(table, name)
 }
 
 func (f *fakeClient) DeleteChain(table, name string) error {
+	atomic.AddUint64(&f.calls, 1)
 	for i := range f.storage {
 		r, ok := f.storage[i].(*rule)
 		if !ok {
@@ -116,6 +141,7 @@ func (f *fakeClient) DeleteChain(table, name string) error {
 }
 
 func (f *fakeClient) NewChain(table, name string) error {
+	atomic.AddUint64(&f.calls, 1)
 	c := &chain{table: table, chain: name}
 	for i := range f.storage {
 		if f.storage[i].String() == c.String() {
@@ -125,3 +151,17 @@ func (f *fakeClient) NewChain(table, name string) error {
 	f.storage = append(f.storage, c)
 	return nil
 }
+
+func (f *fakeClient) ListChains(table string) ([]string, error) {
+	atomic.AddUint64(&f.calls, 1)
+	var cs []string
+	for i := range f.storage {
+		switch c := f.storage[i].(type) {
+		case *chain:
+			if c.table == table {
+				cs = append(cs, c.chain)
+			}
+		}
+	}
+	return cs, nil
+}

+ 67 - 19
pkg/iptables/iptables.go

@@ -17,11 +17,12 @@ package iptables
 import (
 	"fmt"
 	"net"
-	"strings"
 	"sync"
 	"time"
 
 	"github.com/coreos/go-iptables/iptables"
+	"github.com/go-kit/kit/log"
+	"github.com/go-kit/kit/log/level"
 )
 
 // Protocol represents an IP protocol.
@@ -47,9 +48,11 @@ type Client interface {
 	AppendUnique(table string, chain string, rule ...string) error
 	Delete(table string, chain string, rule ...string) error
 	Exists(table string, chain string, rule ...string) (bool, error)
+	List(table string, chain string) ([]string, error)
 	ClearChain(table string, chain string) error
 	DeleteChain(table string, chain string) error
 	NewChain(table string, chain string) error
+	ListChains(table string) ([]string, error)
 }
 
 // Rule is an interface for interacting with iptables objects.
@@ -107,7 +110,17 @@ func (r *rule) String() string {
 	if r == nil {
 		return ""
 	}
-	return fmt.Sprintf("%s_%s_%s", r.table, r.chain, strings.Join(r.spec, "_"))
+	spec := r.table + " -A " + r.chain
+	for i, s := range r.spec {
+		spec += " "
+		// If this is the content of a comment, wrap the value in quotes.
+		if i > 0 && r.spec[i-1] == "--comment" {
+			spec += `"` + s + `"`
+		} else {
+			spec += s
+		}
+	}
+	return spec
 }
 
 func (r *rule) Proto() Protocol {
@@ -132,6 +145,7 @@ func NewIPv6Chain(table, name string) Rule {
 }
 
 func (c *chain) Add(client Client) error {
+	// Note: `ClearChain` creates a chain if it does not exist.
 	if err := client.ClearChain(c.table, c.chain); err != nil {
 		return fmt.Errorf("failed to add iptables chain: %v", err)
 	}
@@ -171,41 +185,73 @@ func (c *chain) String() string {
 	if c == nil {
 		return ""
 	}
-	return fmt.Sprintf("%s_%s", c.table, c.chain)
+	return chainToString(c.table, c.chain)
 }
 
 func (c *chain) Proto() Protocol {
 	return c.proto
 }
 
+func chainToString(table, chain string) string {
+	return fmt.Sprintf("%s -N %s", table, chain)
+}
+
 // Controller is able to reconcile a given set of iptables rules.
 type Controller struct {
 	v4     Client
 	v6     Client
 	errors chan error
+	logger log.Logger
 
 	sync.Mutex
 	rules      []Rule
 	subscribed bool
 }
 
-// New generates a new iptables rules controller.
-// It expects an IP address length to determine
-// whether to operate in IPv4 or IPv6 mode.
-func New() (*Controller, error) {
-	v4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create iptables IPv4 client: %v", err)
+// ControllerOption modifies the controller's configuration.
+type ControllerOption func(h *Controller)
+
+// WithLogger adds a logger to the controller.
+func WithLogger(logger log.Logger) ControllerOption {
+	return func(c *Controller) {
+		c.logger = logger
 	}
-	v6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create iptables IPv6 client: %v", err)
+}
+
+// WithClients adds iptables clients to the controller.
+func WithClients(v4, v6 Client) ControllerOption {
+	return func(c *Controller) {
+		c.v4 = v4
+		c.v6 = v6
 	}
-	return &Controller{
-		v4:     v4,
-		v6:     v6,
+}
+
+// New generates a new iptables rules controller.
+// If no options are given, IPv4 and IPv6 clients
+// will be instantiated using the regular iptables backend.
+func New(opts ...ControllerOption) (*Controller, error) {
+	c := &Controller{
 		errors: make(chan error),
-	}, nil
+		logger: log.NewNopLogger(),
+	}
+	for _, o := range opts {
+		o(c)
+	}
+	if c.v4 == nil {
+		v4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create iptables IPv4 client: %v", err)
+		}
+		c.v4 = v4
+	}
+	if c.v6 == nil {
+		v6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create iptables IPv6 client: %v", err)
+		}
+		c.v6 = v6
+	}
+	return c, nil
 }
 
 // Run watches for changes to iptables rules and reconciles
@@ -223,7 +269,7 @@ func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) {
 		defer close(c.errors)
 		for {
 			select {
-			case <-time.After(5 * time.Second):
+			case <-time.After(30 * time.Second):
 			case <-stop:
 				return
 			}
@@ -242,12 +288,14 @@ func (c *Controller) Run(stop <-chan struct{}) (<-chan error, error) {
 func (c *Controller) reconcile() error {
 	c.Lock()
 	defer c.Unlock()
+	var rc ruleCache
 	for i, r := range c.rules {
-		ok, err := r.Exists(c.client(r.Proto()))
+		ok, err := rc.exists(c.client(r.Proto()), r)
 		if err != nil {
 			return fmt.Errorf("failed to check if rule exists: %v", err)
 		}
 		if !ok {
+			level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i))
 			if err := c.resetFromIndex(i, c.rules); err != nil {
 				return fmt.Errorf("failed to add rule: %v", err)
 			}

+ 8 - 6
pkg/iptables/iptables_test.go

@@ -83,10 +83,11 @@ func TestSet(t *testing.T) {
 			},
 		},
 	} {
-		controller := &Controller{}
 		client := &fakeClient{}
-		controller.v4 = client
-		controller.v6 = client
+		controller, err := New(WithClients(client, client))
+		if err != nil {
+			t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err)
+		}
 		for i := range tc.sets {
 			if err := controller.Set(tc.sets[i]); err != nil {
 				t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err)
@@ -139,10 +140,11 @@ func TestCleanUp(t *testing.T) {
 			rules: []Rule{rules[0], rules[1]},
 		},
 	} {
-		controller := &Controller{}
 		client := &fakeClient{}
-		controller.v4 = client
-		controller.v6 = client
+		controller, err := New(WithClients(client, client))
+		if err != nil {
+			t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err)
+		}
 		if err := controller.Set(tc.rules); err != nil {
 			t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
 		}

+ 106 - 0
pkg/iptables/rulecache.go

@@ -0,0 +1,106 @@
+// Copyright 2021 the Kilo authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+	"fmt"
+	"strings"
+)
+
+type ruleCacheFlag byte
+
+const (
+	exists ruleCacheFlag = 1 << iota
+	populated
+)
+
+type isNotExistError interface {
+	error
+	IsNotExist() bool
+}
+
+// ruleCache is a lazy cache that can be used to
+// check if a given rule or chain exists in an iptables
+// table.
+type ruleCache [2]map[string]ruleCacheFlag
+
+func (rc *ruleCache) populateTable(c Client, proto Protocol, table string) error {
+	// If the table already exists in the destination map,
+	// exit early since it has already been populated.
+	if rc[proto][table]&populated != 0 {
+		return nil
+	}
+	cs, err := c.ListChains(table)
+	if err != nil {
+		return fmt.Errorf("failed to populate chains for table %q: %v", table, err)
+	}
+	rc[proto][table] = exists | populated
+	for i := range cs {
+		rc[proto][chainToString(table, cs[i])] |= exists
+	}
+	return nil
+}
+
+func (rc *ruleCache) populateChain(c Client, proto Protocol, table, chain string) error {
+	// If the destination chain true, then it has already been populated.
+	if rc[proto][chainToString(table, chain)]&populated != 0 {
+		return nil
+	}
+	rs, err := c.List(table, chain)
+	if err != nil {
+		if existsErr, ok := err.(isNotExistError); ok && existsErr.IsNotExist() {
+			rc[proto][chainToString(table, chain)] = populated
+			return nil
+		}
+		return fmt.Errorf("failed to populate rules in chain %q for table %q: %v", chain, table, err)
+	}
+	for i := range rs {
+		rc[proto][strings.Join([]string{table, rs[i]}, " ")] = exists
+	}
+	// If there are rules on the chain, then the chain exists too.
+	if len(rs) > 0 {
+		rc[proto][chainToString(table, chain)] = exists
+	}
+	rc[proto][chainToString(table, chain)] |= populated
+	return nil
+}
+
+func (rc *ruleCache) populateRules(c Client, r Rule) error {
+	// Ensure a map for the proto exists.
+	if rc[r.Proto()] == nil {
+		rc[r.Proto()] = make(map[string]ruleCacheFlag)
+	}
+
+	if ch, ok := r.(*chain); ok {
+		return rc.populateTable(c, r.Proto(), ch.table)
+	}
+
+	ru := r.(*rule)
+	return rc.populateChain(c, r.Proto(), ru.table, ru.chain)
+}
+
+func (rc *ruleCache) exists(c Client, r Rule) (bool, error) {
+	// Exit early if the exact rule exists by name.
+	if rc[r.Proto()][r.String()]&exists != 0 {
+		return true, nil
+	}
+
+	// Otherwise, populate the respective rules.
+	if err := rc.populateRules(c, r); err != nil {
+		return false, err
+	}
+
+	return rc[r.Proto()][r.String()]&exists != 0, nil
+}

+ 125 - 0
pkg/iptables/rulecache_test.go

@@ -0,0 +1,125 @@
+// Copyright 2021 the Kilo authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+	"testing"
+)
+
+func TestRuleCache(t *testing.T) {
+	for _, tc := range []struct {
+		name  string
+		rules []Rule
+		check []Rule
+		out   []bool
+		calls uint64
+	}{
+		{
+			name:  "empty",
+			rules: nil,
+			check: []Rule{rules[0]},
+			out:   []bool{false},
+			calls: 1,
+		},
+		{
+			name:  "single negative",
+			rules: []Rule{rules[1]},
+			check: []Rule{rules[0]},
+			out:   []bool{false},
+			calls: 1,
+		},
+		{
+			name:  "single positive",
+			rules: []Rule{rules[1]},
+			check: []Rule{rules[1]},
+			out:   []bool{true},
+			calls: 1,
+		},
+		{
+			name:  "single chain",
+			rules: []Rule{&chain{"nat", "KILO-NAT", ProtocolIPv4}},
+			check: []Rule{&chain{"nat", "KILO-NAT", ProtocolIPv4}},
+			out:   []bool{true},
+			calls: 1,
+		},
+		{
+			name:  "rule on chain means chain exists",
+			rules: []Rule{rules[0]},
+			check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			out:   []bool{true, true},
+			calls: 1,
+		},
+		{
+			name:  "rule on chain does not mean table is fully populated",
+			rules: []Rule{rules[0], &chain{"filter", "INPUT", ProtocolIPv4}},
+			check: []Rule{rules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}},
+			out:   []bool{true, false, true},
+			calls: 2,
+		},
+		{
+			name:  "multiple rules on chain",
+			rules: []Rule{rules[0], rules[1]},
+			check: []Rule{rules[0], rules[1], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			out:   []bool{true, true, true},
+			calls: 1,
+		},
+		{
+			name:  "checking rule on chain does not mean chain exists",
+			rules: nil,
+			check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			out:   []bool{false, false},
+			calls: 2,
+		},
+		{
+			name:  "multiple chains on same table",
+			rules: nil,
+			check: []Rule{&chain{"filter", "INPUT", ProtocolIPv4}, &chain{"filter", "FORWARD", ProtocolIPv4}},
+			out:   []bool{false, false},
+			calls: 1,
+		},
+		{
+			name:  "multiple chains on different table",
+			rules: nil,
+			check: []Rule{&chain{"filter", "INPUT", ProtocolIPv4}, &chain{"nat", "POSTROUTING", ProtocolIPv4}},
+			out:   []bool{false, false},
+			calls: 2,
+		},
+	} {
+		controller := &Controller{}
+		client := &fakeClient{}
+		controller.v4 = client
+		controller.v6 = client
+		if err := controller.Set(tc.rules); err != nil {
+			t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
+		}
+		// Reset the client's calls so we can examine how many times
+		// the rule cache performs operations.
+		client.calls = 0
+		var rc ruleCache
+		for i := range tc.check {
+			ok, err := rc.exists(controller.client(tc.check[i].Proto()), tc.check[i])
+			if err != nil {
+				t.Fatalf("test case %q check %d: check should not fail: %v", tc.name, i, err)
+			}
+			if ok != tc.out[i] {
+				t.Errorf("test case %q check %d: expected %t, got %t", tc.name, i, tc.out[i], ok)
+			}
+		}
+		if client.calls != tc.calls {
+			t.Errorf("test case %q: expected client to be called %d times, got %d", tc.name, tc.calls, client.calls)
+		}
+	}
+
+}

+ 1 - 1
pkg/mesh/mesh.go

@@ -143,7 +143,7 @@ func New(backend Backend, enc encapsulation.Encapsulator, granularity Granularit
 		level.Debug(logger).Log("msg", "running without a private IP address")
 	}
 	level.Debug(logger).Log("msg", fmt.Sprintf("using %s as the public IP address", publicIP.String()))
-	ipTables, err := iptables.New()
+	ipTables, err := iptables.New(iptables.WithLogger(log.With(logger, "component", "iptables")))
 	if err != nil {
 		return nil, fmt.Errorf("failed to IP tables controller: %v", err)
 	}

+ 5 - 5
pkg/mesh/routes.go

@@ -225,19 +225,19 @@ func (t *Topology) Rules(cni bool) []iptables.Rule {
 	rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
 	rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
 	if cni {
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-s", t.subnet.String(), "-j", "KILO-NAT"))
+		rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(t.subnet.IP)), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
 	}
 	for _, s := range t.segments {
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-d", s.wireGuardIP.String(), "-j", "RETURN"))
+		rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(s.wireGuardIP)), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
 		for _, aip := range s.allowedIPs {
-			rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-d", aip.String(), "-j", "RETURN"))
+			rules = append(rules, iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
 		}
 	}
 	for _, p := range t.peers {
 		for _, aip := range p.AllowedIPs {
 			rules = append(rules,
-				iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-s", aip.String(), "-j", "KILO-NAT"),
-				iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-d", aip.String(), "-j", "RETURN"),
+				iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
+				iptables.NewRule(iptables.GetProtocol(len(aip.IP)), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
 			)
 		}
 	}