Browse Source

Lean iptables updates (#324)

* Split iptables rules into append and prepend rules

* fix existing tests

* add iptables tests which include prepend rules

* fix iptables test storage usage

* add test to check reconcile behaviour

* Reconcile prepend rules

* Make usage of `RuleSet` prettier

* Properly implement `Insert` for fake enc

* Make `Rule.Prepend()` behave like `Rule.Append()` regarding uniqueness

* Implement `Insert()` for `metricsClientWrapper`

* Use `iptables.InsertUnique()` instead of `iptables.Insert()`

---------

Co-authored-by: Clive Jevons <clive@jevons-it.net>
Alex Stockinger 3 years ago
parent
commit
12ad2752d2

+ 1 - 1
go.mod

@@ -7,7 +7,7 @@ require (
 	github.com/campoy/embedmd v1.0.0
 	github.com/containernetworking/cni v1.0.1
 	github.com/containernetworking/plugins v1.1.1
-	github.com/coreos/go-iptables v0.6.0
+	github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1
 	github.com/go-kit/kit v0.9.0
 	github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348
 	github.com/metalmatze/signal v0.0.0-20210307161603-1c9aa721a97a

+ 2 - 2
go.sum

@@ -106,8 +106,8 @@ github.com/containernetworking/plugins v1.1.1 h1:+AGfFigZ5TiQH00vhR8qPeSatj53eNG
 github.com/containernetworking/plugins v1.1.1/go.mod h1:Sr5TH/eBsGLXK/h71HeLfX19sZPp3ry5uHSkI4LPxV8=
 github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
 github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
-github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk=
-github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
+github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1 h1:zSiUKnogKeEwIIeUQP/WPH7m0BJ/IvW0VyL4muaauUY=
+github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
 github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
 github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
 github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=

+ 2 - 2
pkg/encapsulation/cilium.go

@@ -96,8 +96,8 @@ func (f *cilium) Init(_ int) error {
 }
 
 // Rules is a no-op.
-func (f *cilium) Rules(_ []*net.IPNet) []iptables.Rule {
-	return nil
+func (f *cilium) Rules(_ []*net.IPNet) iptables.RuleSet {
+	return iptables.RuleSet{}
 }
 
 // Set is a no-op.

+ 1 - 1
pkg/encapsulation/encapsulation.go

@@ -49,7 +49,7 @@ type Encapsulator interface {
 	Gw(net.IP, net.IP, *net.IPNet) net.IP
 	Index() int
 	Init(int) error
-	Rules([]*net.IPNet) []iptables.Rule
+	Rules([]*net.IPNet) iptables.RuleSet
 	Set(*net.IPNet) error
 	Strategy() Strategy
 }

+ 2 - 2
pkg/encapsulation/flannel.go

@@ -95,8 +95,8 @@ func (f *flannel) Init(_ int) error {
 }
 
 // Rules is a no-op.
-func (f *flannel) Rules(_ []*net.IPNet) []iptables.Rule {
-	return nil
+func (f *flannel) Rules(_ []*net.IPNet) iptables.RuleSet {
+	return iptables.RuleSet{}
 }
 
 // Set is a no-op.

+ 9 - 9
pkg/encapsulation/ipip.go

@@ -65,20 +65,20 @@ func (i *ipip) Init(base int) error {
 
 // Rules returns a set of iptables rules that are necessary
 // when traffic between nodes must be encapsulated.
-func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule {
-	var rules []iptables.Rule
+func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet {
+	rules := iptables.RuleSet{}
 	proto := ipipProtocolName()
-	rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP"))
-	rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP"))
-	rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
-	rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
+	rules.AddToAppend(iptables.NewIPv4Chain("filter", "KILO-IPIP"))
+	rules.AddToAppend(iptables.NewIPv6Chain("filter", "KILO-IPIP"))
+	rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
+	rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP"))
 	for _, n := range nodes {
 		// Accept encapsulated traffic from peers.
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
+		rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT"))
 	}
 	// Drop all other IPIP traffic.
-	rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
-	rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
+	rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
+	rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP"))
 
 	return rules
 }

+ 2 - 2
pkg/encapsulation/noop.go

@@ -44,8 +44,8 @@ func (n Noop) Init(_ int) error {
 }
 
 // Rules will also do nothing.
-func (n Noop) Rules(_ []*net.IPNet) []iptables.Rule {
-	return nil
+func (n Noop) Rules(_ []*net.IPNet) iptables.RuleSet {
+	return iptables.RuleSet{}
 }
 
 // Set will also do nothing.

+ 18 - 0
pkg/iptables/fake.go

@@ -46,6 +46,24 @@ type fakeClient struct {
 
 var _ Client = &fakeClient{}
 
+func (f *fakeClient) InsertUnique(table, chain string, pos int, spec ...string) error {
+	atomic.AddUint64(&f.calls, 1)
+	exists, err := f.Exists(table, chain, spec...)
+	if err != nil {
+		return err
+	}
+	if exists {
+		return nil
+	}
+	index := pos - 1 // iptables are 1-based
+	rule := &rule{table: table, chain: chain, spec: spec}
+	prefix := append([]Rule{}, f.storage[:index]...)
+	suffix := append([]Rule{}, f.storage[index:]...)
+	prefix = append(prefix, rule)
+	f.storage = append(prefix, suffix...)
+	return nil
+}
+
 func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error {
 	atomic.AddUint64(&f.calls, 1)
 	exists, err := f.Exists(table, chain, spec...)

+ 129 - 19
pkg/iptables/iptables.go

@@ -46,6 +46,11 @@ func ipv6Disabled() (bool, error) {
 // Protocol represents an IP protocol.
 type Protocol byte
 
+type RuleSet struct {
+	appendRules  []Rule // Rules to append to the chain - order matters.
+	prependRules []Rule // Rules to prepend to the chain - order does not matter.
+}
+
 const (
 	// ProtocolIPv4 represents the IPv4 protocol.
 	ProtocolIPv4 Protocol = iota
@@ -53,6 +58,21 @@ const (
 	ProtocolIPv6
 )
 
+func (rs *RuleSet) AddToAppend(rule Rule) {
+	rs.appendRules = append(rs.appendRules, rule)
+}
+
+func (rs *RuleSet) AddToPrepend(rule Rule) {
+	rs.prependRules = append(rs.prependRules, rule)
+}
+
+func (rs *RuleSet) AppendRuleSet(other RuleSet) RuleSet {
+	return RuleSet{
+		appendRules:  append(rs.appendRules, other.appendRules...),
+		prependRules: append(rs.prependRules, other.prependRules...),
+	}
+}
+
 // GetProtocol will return a protocol from the length of an IP address.
 func GetProtocol(ip net.IP) Protocol {
 	if len(ip) == net.IPv4len || ip.To4() != nil {
@@ -64,6 +84,7 @@ func GetProtocol(ip net.IP) Protocol {
 // Client represents any type that can administer iptables rules.
 type Client interface {
 	AppendUnique(table string, chain string, rule ...string) error
+	InsertUnique(table, chain string, pos int, rule ...string) error
 	Delete(table string, chain string, rule ...string) error
 	Exists(table string, chain string, rule ...string) (bool, error)
 	List(table string, chain string) ([]string, error)
@@ -75,7 +96,8 @@ type Client interface {
 
 // Rule is an interface for interacting with iptables objects.
 type Rule interface {
-	Add(Client) error
+	Append(Client) error
+	Prepend(Client) error
 	Delete(Client) error
 	Exists(Client) (bool, error)
 	String() string
@@ -106,7 +128,14 @@ func NewIPv6Rule(table, chain string, spec ...string) Rule {
 	return &rule{table, chain, spec, ProtocolIPv6}
 }
 
-func (r *rule) Add(client Client) error {
+func (r *rule) Prepend(client Client) error {
+	if err := client.InsertUnique(r.table, r.chain, 1, r.spec...); err != nil {
+		return fmt.Errorf("failed to add iptables rule: %v", err)
+	}
+	return nil
+}
+
+func (r *rule) Append(client Client) error {
 	if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil {
 		return fmt.Errorf("failed to add iptables rule: %v", err)
 	}
@@ -162,7 +191,11 @@ func NewIPv6Chain(table, name string) Rule {
 	return &chain{table, name, ProtocolIPv6}
 }
 
-func (c *chain) Add(client Client) error {
+func (c *chain) Prepend(client Client) error {
+	return c.Append(client)
+}
+
+func (c *chain) Append(client Client) error {
 	// Note: `ClearChain` creates a chain if it does not exist.
 	if err := client.ClearChain(c.table, c.chain); err != nil {
 		return fmt.Errorf("failed to add iptables chain: %v", err)
@@ -224,8 +257,9 @@ type Controller struct {
 	registerer   prometheus.Registerer
 
 	sync.Mutex
-	rules      []Rule
-	subscribed bool
+	appendRules  []Rule
+	prependRules []Rule
+	subscribed   bool
 }
 
 // ControllerOption modifies the controller's configuration.
@@ -333,14 +367,21 @@ func (c *Controller) reconcile() error {
 	c.Lock()
 	defer c.Unlock()
 	var rc ruleCache
-	for i, r := range c.rules {
+	if err := c.reconcileAppendRules(rc); err != nil {
+		return err
+	}
+	return c.reconcilePrependRules(rc)
+}
+
+func (c *Controller) reconcileAppendRules(rc ruleCache) error {
+	for i, r := range c.appendRules {
 		ok, err := rc.exists(c.client(r.Proto()), r)
 		if err != nil {
 			return fmt.Errorf("failed to check if rule exists: %v", err)
 		}
 		if !ok {
-			level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i))
-			if err := c.resetFromIndex(i, c.rules); err != nil {
+			level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.appendRules)-i))
+			if err := c.resetFromIndex(i, c.appendRules); err != nil {
 				return fmt.Errorf("failed to add rule: %v", err)
 			}
 			break
@@ -349,6 +390,22 @@ func (c *Controller) reconcile() error {
 	return nil
 }
 
+func (c *Controller) reconcilePrependRules(rc ruleCache) error {
+	for _, r := range c.prependRules {
+		ok, err := rc.exists(c.client(r.Proto()), r)
+		if err != nil {
+			return fmt.Errorf("failed to check if rule exists: %v", err)
+		}
+		if !ok {
+			level.Info(c.logger).Log("msg", "prepending iptables rule")
+			if err := r.Prepend(c.client(r.Proto())); err != nil {
+				return fmt.Errorf("failed to prepend rule: %v", err)
+			}
+		}
+	}
+	return nil
+}
+
 // resetFromIndex re-adds all rules starting from the given index.
 func (c *Controller) resetFromIndex(i int, rules []Rule) error {
 	if i >= len(rules) {
@@ -358,7 +415,7 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error {
 		if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil {
 			return fmt.Errorf("failed to delete rule: %v", err)
 		}
-		if err := rules[j].Add(c.client(rules[j].Proto())); err != nil {
+		if err := rules[j].Append(c.client(rules[j].Proto())); err != nil {
 			return fmt.Errorf("failed to add rule: %v", err)
 		}
 	}
@@ -383,34 +440,87 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error {
 
 // Set idempotently overwrites any iptables rules previously defined
 // for the controller with the given set of rules.
-func (c *Controller) Set(rules []Rule) error {
+func (c *Controller) Set(rules RuleSet) error {
 	c.Lock()
 	defer c.Unlock()
+	if err := c.setAppendRules(rules.appendRules); err != nil {
+		return err
+	}
+	return c.setPrependRules(rules.prependRules)
+}
+
+func (c *Controller) setAppendRules(appendRules []Rule) error {
 	var i int
-	for ; i < len(rules); i++ {
-		if i < len(c.rules) {
-			if rules[i].String() != c.rules[i].String() {
-				if err := c.deleteFromIndex(i, &c.rules); err != nil {
+	for ; i < len(appendRules); i++ {
+		if i < len(c.appendRules) {
+			if appendRules[i].String() != c.appendRules[i].String() {
+				if err := c.deleteFromIndex(i, &c.appendRules); err != nil {
 					return err
 				}
 			}
 		}
-		if i >= len(c.rules) {
-			if err := rules[i].Add(c.client(rules[i].Proto())); err != nil {
+		if i >= len(c.appendRules) {
+			if err := appendRules[i].Append(c.client(appendRules[i].Proto())); err != nil {
 				return fmt.Errorf("failed to add rule: %v", err)
 			}
-			c.rules = append(c.rules, rules[i])
+			c.appendRules = append(c.appendRules, appendRules[i])
 		}
+	}
+	err := c.deleteFromIndex(i, &c.appendRules)
+	if err != nil {
+		return fmt.Errorf("failed to delete rule: %v", err)
+	}
+	return nil
+}
 
+func (c *Controller) setPrependRules(prependRules []Rule) error {
+	for _, prependRule := range prependRules {
+		if !containsRule(c.prependRules, prependRule) {
+			if err := prependRule.Prepend(c.client(prependRule.Proto())); err != nil {
+				return fmt.Errorf("failed to add rule: %v", err)
+			}
+			c.prependRules = append(c.prependRules, prependRule)
+		}
 	}
-	return c.deleteFromIndex(i, &c.rules)
+	for _, existingRule := range c.prependRules {
+		if !containsRule(prependRules, existingRule) {
+			if err := existingRule.Delete(c.client(existingRule.Proto())); err != nil {
+				return fmt.Errorf("failed to delete rule: %v", err)
+			}
+			c.prependRules = removeRule(c.prependRules, existingRule)
+		}
+	}
+	return nil
+}
+
+func removeRule(rules []Rule, toRemove Rule) []Rule {
+	ret := make([]Rule, 0, len(rules))
+	for _, rule := range rules {
+		if rule.String() != toRemove.String() {
+			ret = append(ret, rule)
+		}
+	}
+	return ret
+}
+
+func containsRule(haystack []Rule, needle Rule) bool {
+	for _, element := range haystack {
+		if element.String() == needle.String() {
+			return true
+		}
+	}
+	return false
 }
 
 // CleanUp will clean up any rules created by the controller.
 func (c *Controller) CleanUp() error {
 	c.Lock()
 	defer c.Unlock()
-	return c.deleteFromIndex(0, &c.rules)
+	err := c.deleteFromIndex(0, &c.prependRules)
+	if err != nil {
+		return err
+	}
+	return c.deleteFromIndex(0, &c.appendRules)
 }
 
 func (c *Controller) client(p Protocol) Client {

+ 125 - 46
pkg/iptables/iptables_test.go

@@ -18,70 +18,94 @@ import (
 	"testing"
 )
 
-var rules = []Rule{
+var appendRules = []Rule{
 	NewIPv4Rule("filter", "FORWARD", "-s", "10.4.0.0/16", "-j", "ACCEPT"),
 	NewIPv4Rule("filter", "FORWARD", "-d", "10.4.0.0/16", "-j", "ACCEPT"),
 }
 
+var prependRules = []Rule{
+	NewIPv4Rule("filter", "FORWARD", "-s", "10.5.0.0/16", "-j", "DROP"),
+	NewIPv4Rule("filter", "FORWARD", "-s", "10.6.0.0/16", "-j", "DROP"),
+}
+
 func TestSet(t *testing.T) {
 	for _, tc := range []struct {
-		name    string
-		sets    [][]Rule
-		out     []Rule
-		actions []func(Client) error
+		name       string
+		sets       []RuleSet
+		appendOut  []Rule
+		prependOut []Rule
+		storageOut []Rule
+		actions    []func(Client) error
 	}{
 		{
 			name: "empty",
 		},
 		{
 			name: "single",
-			sets: [][]Rule{
-				{rules[0]},
+			sets: []RuleSet{
+				{appendRules: []Rule{appendRules[0]}},
 			},
-			out: []Rule{rules[0]},
+			appendOut:  []Rule{appendRules[0]},
+			storageOut: []Rule{appendRules[0]},
 		},
 		{
 			name: "two rules",
-			sets: [][]Rule{
-				{rules[0], rules[1]},
+			sets: []RuleSet{
+				{appendRules: []Rule{appendRules[0], appendRules[1]}},
 			},
-			out: []Rule{rules[0], rules[1]},
+			appendOut:  []Rule{appendRules[0], appendRules[1]},
+			storageOut: []Rule{appendRules[0], appendRules[1]},
 		},
 		{
 			name: "multiple",
-			sets: [][]Rule{
-				{rules[0], rules[1]},
-				{rules[1]},
+			sets: []RuleSet{
+				{appendRules: []Rule{appendRules[0], appendRules[1]}},
+				{appendRules: []Rule{appendRules[1]}},
 			},
-			out: []Rule{rules[1]},
+			appendOut:  []Rule{appendRules[1]},
+			storageOut: []Rule{appendRules[1]},
 		},
 		{
 			name: "re-add",
-			sets: [][]Rule{
-				{rules[0], rules[1]},
+			sets: []RuleSet{
+				{appendRules: []Rule{appendRules[0], appendRules[1]}},
 			},
-			out: []Rule{rules[0], rules[1]},
+			appendOut:  []Rule{appendRules[0], appendRules[1]},
+			storageOut: []Rule{appendRules[0], appendRules[1]},
 			actions: []func(c Client) error{
 				func(c Client) error {
-					return rules[0].Delete(c)
+					return appendRules[0].Delete(c)
 				},
 				func(c Client) error {
-					return rules[1].Delete(c)
+					return appendRules[1].Delete(c)
 				},
 			},
 		},
 		{
 			name: "order",
-			sets: [][]Rule{
-				{rules[0], rules[1]},
+			sets: []RuleSet{
+				{appendRules: []Rule{appendRules[0], appendRules[1]}},
 			},
-			out: []Rule{rules[0], rules[1]},
+			appendOut:  []Rule{appendRules[0], appendRules[1]},
+			storageOut: []Rule{appendRules[0], appendRules[1]},
 			actions: []func(c Client) error{
 				func(c Client) error {
-					return rules[0].Delete(c)
+					return appendRules[0].Delete(c)
 				},
 			},
 		},
+		{
+			name: "append and prepend",
+			sets: []RuleSet{
+				{
+					prependRules: []Rule{prependRules[0], prependRules[1]},
+					appendRules:  []Rule{appendRules[0], appendRules[1]},
+				},
+			},
+			appendOut:  []Rule{appendRules[0], appendRules[1]},
+			prependOut: []Rule{prependRules[0], prependRules[1]},
+			storageOut: []Rule{prependRules[1], prependRules[0], appendRules[0], appendRules[1]},
+		},
 	} {
 		client := &fakeClient{}
 		controller, err := New(WithClients(client, client))
@@ -90,7 +114,7 @@ func TestSet(t *testing.T) {
 		}
 		for i := range tc.sets {
 			if err := controller.Set(tc.sets[i]); err != nil {
-				t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err)
+				t.Fatalf("test case %q: got unexpected error setting rule set %d: %v", tc.name, i, err)
 			}
 		}
 		for i, f := range tc.actions {
@@ -101,21 +125,30 @@ func TestSet(t *testing.T) {
 		if err := controller.reconcile(); err != nil {
 			t.Fatalf("test case %q: got unexpected error %v", tc.name, err)
 		}
-		if len(tc.out) != len(client.storage) {
-			t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.out), len(client.storage))
+		if len(tc.storageOut) != len(client.storage) {
+			t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.storageOut), len(client.storage))
 		} else {
-			for i := range tc.out {
-				if tc.out[i].String() != client.storage[i].String() {
-					t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.out[i], client.storage[i])
+			for i := range tc.storageOut {
+				if tc.storageOut[i].String() != client.storage[i].String() {
+					t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.storageOut[i], client.storage[i])
 				}
 			}
 		}
-		if len(tc.out) != len(controller.rules) {
-			t.Errorf("test case %q: expected %d rules in controller, got %d", tc.name, len(tc.out), len(controller.rules))
+		if len(tc.appendOut) != len(controller.appendRules) {
+			t.Errorf("test case %q: expected %d appendRules in controller, got %d", tc.name, len(tc.appendOut), len(controller.appendRules))
 		} else {
-			for i := range tc.out {
-				if tc.out[i].String() != controller.rules[i].String() {
-					t.Errorf("test case %q: expected rule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.out[i], controller.rules[i])
+			for i := range tc.appendOut {
+				if tc.appendOut[i].String() != controller.appendRules[i].String() {
+					t.Errorf("test case %q: expected appendRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.appendOut[i], controller.appendRules[i])
+				}
+			}
+		}
+		if len(tc.prependOut) != len(controller.prependRules) {
+			t.Errorf("test case %q: expected %d prependRules in controller, got %d", tc.name, len(tc.prependOut), len(controller.prependRules))
+		} else {
+			for i := range tc.prependOut {
+				if tc.prependOut[i].String() != controller.prependRules[i].String() {
+					t.Errorf("test case %q: expected prependRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.prependOut[i], controller.prependRules[i])
 				}
 			}
 		}
@@ -124,20 +157,26 @@ func TestSet(t *testing.T) {
 
 func TestCleanUp(t *testing.T) {
 	for _, tc := range []struct {
-		name  string
-		rules []Rule
+		name         string
+		appendRules  []Rule
+		prependRules []Rule
 	}{
 		{
-			name:  "empty",
-			rules: nil,
+			name:        "empty",
+			appendRules: nil,
+		},
+		{
+			name:        "single append",
+			appendRules: []Rule{appendRules[0]},
 		},
 		{
-			name:  "single",
-			rules: []Rule{rules[0]},
+			name:        "multiple append",
+			appendRules: []Rule{appendRules[0], appendRules[1]},
 		},
 		{
-			name:  "multiple",
-			rules: []Rule{rules[0], rules[1]},
+			name:         "multiple append and prepend",
+			appendRules:  []Rule{appendRules[0], appendRules[1]},
+			prependRules: []Rule{prependRules[0], prependRules[1]},
 		},
 	} {
 		client := &fakeClient{}
@@ -145,11 +184,12 @@ func TestCleanUp(t *testing.T) {
 		if err != nil {
 			t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err)
 		}
-		if err := controller.Set(tc.rules); err != nil {
+		ruleSet := RuleSet{appendRules: tc.appendRules, prependRules: tc.prependRules}
+		if err := controller.Set(ruleSet); err != nil {
 			t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
 		}
-		if len(client.storage) != len(tc.rules) {
-			t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(tc.rules), len(client.storage))
+		if len(client.storage) != len(tc.appendRules)+len(tc.prependRules) {
+			t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.appendRules)+len(ruleSet.prependRules), len(client.storage))
 		}
 		if err := controller.CleanUp(); err != nil {
 			t.Errorf("test case %q: got unexpected error: %v", tc.name, err)
@@ -159,3 +199,42 @@ func TestCleanUp(t *testing.T) {
 		}
 	}
 }
+
+func TestReconcile(t *testing.T) {
+	for _, tc := range []struct {
+		name         string
+		appendRules  []Rule
+		prependRules []Rule
+		storageOut   []Rule
+	}{
+		{
+			name:         "append and prepend rules",
+			appendRules:  []Rule{appendRules[0], appendRules[1]},
+			prependRules: []Rule{prependRules[0], prependRules[1]},
+			storageOut:   []Rule{prependRules[1], prependRules[0], appendRules[0], appendRules[1]},
+		},
+	} {
+		client := &fakeClient{}
+		controller, err := New(WithClients(client, client))
+		if err != nil {
+			t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err)
+		}
+		controller.appendRules = tc.appendRules
+		controller.prependRules = tc.prependRules
+
+		err = controller.reconcile()
+		if err != nil {
+			t.Fatalf("test case %q: unexpected error during reconcile: %v", tc.name, err)
+		}
+
+		if len(tc.storageOut) != len(client.storage) {
+			t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.storageOut), len(client.storage))
+		} else {
+			for i := range tc.storageOut {
+				if tc.storageOut[i].String() != client.storage[i].String() {
+					t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.storageOut[i], client.storage[i])
+				}
+			}
+		}
+	}
+}

+ 9 - 0
pkg/iptables/metrics.go

@@ -51,6 +51,15 @@ func (m *metricsClientWrapper) AppendUnique(table string, chain string, rule ...
 	return m.client.AppendUnique(table, chain, rule...)
 }
 
+func (m *metricsClientWrapper) InsertUnique(table, chain string, pos int, rule ...string) error {
+	m.operationCounter.With(prometheus.Labels{
+		"operation": "InsertUnique",
+		"table":     table,
+		"chain":     chain,
+	}).Inc()
+	return m.client.InsertUnique(table, chain, pos, rule...)
+}
+
 func (m *metricsClientWrapper) Delete(table string, chain string, rule ...string) error {
 	m.operationCounter.With(prometheus.Labels{
 		"operation": "Delete",

+ 14 - 13
pkg/iptables/rulecache_test.go

@@ -29,21 +29,21 @@ func TestRuleCache(t *testing.T) {
 		{
 			name:  "empty",
 			rules: nil,
-			check: []Rule{rules[0]},
+			check: []Rule{appendRules[0]},
 			out:   []bool{false},
 			calls: 1,
 		},
 		{
 			name:  "single negative",
-			rules: []Rule{rules[1]},
-			check: []Rule{rules[0]},
+			rules: []Rule{appendRules[1]},
+			check: []Rule{appendRules[0]},
 			out:   []bool{false},
 			calls: 1,
 		},
 		{
 			name:  "single positive",
-			rules: []Rule{rules[1]},
-			check: []Rule{rules[1]},
+			rules: []Rule{appendRules[1]},
+			check: []Rule{appendRules[1]},
 			out:   []bool{true},
 			calls: 1,
 		},
@@ -56,29 +56,29 @@ func TestRuleCache(t *testing.T) {
 		},
 		{
 			name:  "rule on chain means chain exists",
-			rules: []Rule{rules[0]},
-			check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			rules: []Rule{appendRules[0]},
+			check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
 			out:   []bool{true, true},
 			calls: 1,
 		},
 		{
 			name:  "rule on chain does not mean table is fully populated",
-			rules: []Rule{rules[0], &chain{"filter", "INPUT", ProtocolIPv4}},
-			check: []Rule{rules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}},
+			rules: []Rule{appendRules[0], &chain{"filter", "INPUT", ProtocolIPv4}},
+			check: []Rule{appendRules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}},
 			out:   []bool{true, false, true},
 			calls: 2,
 		},
 		{
 			name:  "multiple rules on chain",
-			rules: []Rule{rules[0], rules[1]},
-			check: []Rule{rules[0], rules[1], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			rules: []Rule{appendRules[0], appendRules[1]},
+			check: []Rule{appendRules[0], appendRules[1], &chain{"filter", "FORWARD", ProtocolIPv4}},
 			out:   []bool{true, true, true},
 			calls: 1,
 		},
 		{
 			name:  "checking rule on chain does not mean chain exists",
 			rules: nil,
-			check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
+			check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}},
 			out:   []bool{false, false},
 			calls: 2,
 		},
@@ -101,7 +101,8 @@ func TestRuleCache(t *testing.T) {
 		client := &fakeClient{}
 		controller.v4 = client
 		controller.v6 = client
-		if err := controller.Set(tc.rules); err != nil {
+		ruleSet := RuleSet{appendRules: tc.rules}
+		if err := controller.Set(ruleSet); err != nil {
 			t.Fatalf("test case %q: Set should not fail: %v", tc.name, err)
 		}
 		// Reset the client's calls so we can examine how many times

+ 2 - 1
pkg/mesh/mesh.go

@@ -526,7 +526,8 @@ func (m *Mesh) applyTopology() {
 			}
 		}
 
-		ipRules = append(m.enc.Rules(cidrs), ipRules...)
+		encIpRules := m.enc.Rules(cidrs)
+		ipRules = encIpRules.AppendRuleSet(ipRules)
 
 		// If we are handling local routes, ensure the local
 		// tunnel has an IP address.

+ 21 - 25
pkg/mesh/routes.go

@@ -311,12 +311,12 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy,
 }
 
 // Rules returns the iptables rules required by the local node.
-func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
-	var rules []iptables.Rule
-	rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT"))
-	rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT"))
+func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet {
+	rules := iptables.RuleSet{}
+	rules.AddToAppend(iptables.NewIPv4Chain("nat", "KILO-NAT"))
+	rules.AddToAppend(iptables.NewIPv6Chain("nat", "KILO-NAT"))
 	if cni {
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
+		rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT"))
 		// Some linux distros or docker will set forward DROP in the filter table.
 		// To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location.
 		// Leader nodes will forward packets from all nodes within a location because they act as a gateway for them.
@@ -326,55 +326,51 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule {
 				if s.location == t.location {
 					// Make sure packets to and from pod cidrs are not dropped in the forward chain.
 					for _, c := range s.cidrs {
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT"))
 					}
 					// Make sure packets to and from allowed location IPs are not dropped in the forward chain.
 					for _, c := range s.allowedLocationIPs {
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT"))
 					}
 					// Make sure packets to and from private IPs are not dropped in the forward chain.
 					for _, c := range s.privateIPs {
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
-						rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
+						rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT"))
 					}
 				}
 			}
 		} else if iptablesForwardRule {
-			rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
-			rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
+			rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT"))
+			rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT"))
 		}
 	}
 	for _, s := range t.segments {
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
+		rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN"))
 		for _, aip := range s.allowedIPs {
-			rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
+			rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN"))
 		}
 		// Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd,
 		// Otherwise packets to these destinations will reach the destination, but never find their way back.
 		// We only want to NAT in locations of the corresponding allowed location IPs.
 		if t.location == s.location {
 			for _, alip := range s.allowedLocationIPs {
-				rules = append(rules,
-					iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
-				)
+				rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"))
 			}
 		}
 	}
 	for _, p := range t.peers {
 		for _, aip := range p.AllowedIPs {
-			rules = append(rules,
-				iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"),
-				iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"),
-			)
+			rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"))
+			rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"))
 		}
 	}
 	for _, s := range t.serviceCIDRs {
-		rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.IP), "nat", "KILO-NAT", "-d", s.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for service CIDRs", "-j", "RETURN"))
+		rules.AddToAppend(iptables.NewRule(iptables.GetProtocol(s.IP), "nat", "KILO-NAT", "-d", s.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for service CIDRs", "-j", "RETURN"))
 	}
-	rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
-	rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
+	rules.AddToAppend(iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
+	rules.AddToAppend(iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE"))
 	return rules
 }
 

+ 29 - 2
vendor/github.com/coreos/go-iptables/iptables/iptables.go

@@ -109,6 +109,7 @@ func Timeout(timeout int) option {
 // For backwards compatibility, by default always uses IPv4 and timeout 0.
 // i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
 // the IPFamily and Timeout options as follow:
+//
 //	ip6t := New(IPFamily(ProtocolIPv6), Timeout(5))
 func New(opts ...option) (*IPTables, error) {
 
@@ -185,6 +186,20 @@ func (ipt *IPTables) Insert(table, chain string, pos int, rulespec ...string) er
 	return ipt.run(cmd...)
 }
 
+// InsertUnique acts like Insert except that it won't insert a duplicate (no matter the position in the chain)
+func (ipt *IPTables) InsertUnique(table, chain string, pos int, rulespec ...string) error {
+	exists, err := ipt.Exists(table, chain, rulespec...)
+	if err != nil {
+		return err
+	}
+
+	if !exists {
+		return ipt.Insert(table, chain, pos, rulespec...)
+	}
+
+	return nil
+}
+
 // Append appends rulespec to specified table/chain
 func (ipt *IPTables) Append(table, chain string, rulespec ...string) error {
 	cmd := append([]string{"-t", table, "-A", chain}, rulespec...)
@@ -219,6 +234,16 @@ func (ipt *IPTables) DeleteIfExists(table, chain string, rulespec ...string) err
 	return err
 }
 
+// List rules in specified table/chain
+func (ipt *IPTables) ListById(table, chain string, id int) (string, error) {
+	args := []string{"-t", table, "-S", chain, strconv.Itoa(id)}
+	rule, err := ipt.executeList(args)
+	if err != nil {
+		return "", err
+	}
+	return rule[0], nil
+}
+
 // List rules in specified table/chain
 func (ipt *IPTables) List(table, chain string) ([]string, error) {
 	args := []string{"-t", table, "-S", chain}
@@ -510,7 +535,9 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
 			syscall.Close(fmu.fd)
 			return err
 		}
-		defer ul.Unlock()
+		defer func() {
+			_ = ul.Unlock()
+		}()
 	}
 
 	var stderr bytes.Buffer
@@ -619,7 +646,7 @@ func iptablesHasWaitCommand(v1 int, v2 int, v3 int) bool {
 	return false
 }
 
-//Checks if an iptablse version is after 1.6.0, when --wait support second
+// Checks if an iptablse version is after 1.6.0, when --wait support second
 func iptablesWaitSupportSecond(v1 int, v2 int, v3 int) bool {
 	if v1 > 1 {
 		return true

+ 1 - 1
vendor/modules.txt

@@ -39,7 +39,7 @@ github.com/containernetworking/plugins/pkg/ns
 github.com/containernetworking/plugins/pkg/utils/sysctl
 github.com/containernetworking/plugins/plugins/ipam/host-local/backend
 github.com/containernetworking/plugins/plugins/ipam/host-local/backend/allocator
-# github.com/coreos/go-iptables v0.6.0
+# github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1
 ## explicit; go 1.16
 github.com/coreos/go-iptables/iptables
 # github.com/davecgh/go-spew v1.1.1