Răsfoiți Sursa

Code and tests to make sure that same security rule can be added twice.

Nuwan Goonasekera 10 ani în urmă
părinte
comite
ff4f79cc79

+ 2 - 1
cloudbridge/cloud/interfaces/resources.py

@@ -1663,7 +1663,8 @@ class SecurityGroup(CloudResource):
     def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
                  cidr_ip=None, src_group=None):
         """
-        Create a security group rule.
+        Create a security group rule. If the rule already exists, simply
+        returns it.
 
         You need to pass in either ``src_group`` OR ``ip_protocol``,
         ``from_port``, ``to_port``, and ``cidr_ip``. In other words, either

+ 9 - 2
cloudbridge/cloud/providers/aws/resources.py

@@ -17,6 +17,7 @@ from cloudbridge.cloud.base.resources import BaseSnapshot
 from cloudbridge.cloud.base.resources import BaseSubnet
 from cloudbridge.cloud.base.resources import BaseVolume
 from cloudbridge.cloud.base.resources import ClientPagedResultList
+from cloudbridge.cloud.interfaces.resources import SecurityGroup
 from cloudbridge.cloud.interfaces.resources import InstanceState
 from cloudbridge.cloud.interfaces.resources import MachineImageState
 from cloudbridge.cloud.interfaces.resources import NetworkState
@@ -632,13 +633,18 @@ class AWSSecurityGroup(BaseSecurityGroup):
         :return: Rule object if successful or ``None``.
         """
         try:
+            if not isinstance(src_group, SecurityGroup):
+                src_group = self._provider.security.security_groups.get(
+                                src_group)
+
             if self._security_group.authorize(
                     ip_protocol=ip_protocol,
                     from_port=from_port,
                     to_port=to_port,
                     cidr_ip=cidr_ip,
                     # pylint:disable=protected-access
-                    src_group=src_group._security_group if src_group else None):
+                    src_group=src_group._security_group if src_group
+                    else None):
                 return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
                                      src_group)
         except EC2ResponseError as ec2e:
@@ -765,7 +771,8 @@ class AWSBucketObject(BaseBucketObject):
         """
         Get the date and time this object was last modified.
         """
-        lm = datetime.strptime(self._key.last_modified, "%Y-%m-%dT%H:%M:%S.%fZ")
+        lm = datetime.strptime(self._key.last_modified,
+                               "%Y-%m-%dT%H:%M:%S.%fZ")
         return lm.strftime("%Y-%m-%dT%H:%M:%S.%f")
 
     def iter_content(self):

+ 18 - 0
cloudbridge/cloud/providers/openstack/resources.py

@@ -21,6 +21,7 @@ from cloudbridge.cloud.interfaces.resources import MachineImageState
 from cloudbridge.cloud.interfaces.resources import NetworkState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import VolumeState
+from cloudbridge.cloud.interfaces.resources import SecurityGroup
 from cloudbridge.cloud.providers.openstack import helpers as oshelpers
 import inspect
 import json
@@ -787,7 +788,17 @@ class OpenStackSecurityGroup(BaseSecurityGroup):
         :return: Rule object if successful or ``None``.
         """
         if src_group:
+            if not isinstance(src_group, SecurityGroup):
+                src_group = self._provider.security.security_groups.get(
+                                src_group)
             for protocol in ['udp', 'tcp']:
+                existing_rule = self.get_rule(ip_protocol=ip_protocol,
+                                              from_port=1,
+                                              to_port=65535,
+                                              src_group=src_group)
+                if existing_rule:
+                    return existing_rule
+
                 rule = self._provider.nova.security_group_rules.create(
                     parent_group_id=self._security_group.id,
                     ip_protocol=protocol,
@@ -800,6 +811,13 @@ class OpenStackSecurityGroup(BaseSecurityGroup):
                 return OpenStackSecurityGroupRule(self._provider,
                                                   rule.to_dict(), self)
         else:
+            existing_rule = self.get_rule(ip_protocol=ip_protocol,
+                                          from_port=from_port,
+                                          to_port=to_port,
+                                          cidr_ip=cidr_ip)
+            if existing_rule:
+                return existing_rule
+
             rule = self._provider.nova.security_group_rules.create(
                 parent_group_id=self._security_group.id,
                 ip_protocol=ip_protocol,

+ 19 - 2
test/test_security_service.py

@@ -48,8 +48,9 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                 "Get key pair did not return the expected key {0}."
                 .format(name))
 
+            # Recreating existing keypair should raise an exception
             with self.assertRaises(Exception):
-                recreated_kp = self.provider.security.key_pairs.create(name=name)
+                self.provider.security.key_pairs.create(name=name)
         kpl = self.provider.security.key_pairs.list()
         found_kp = [k for k in kpl if k.name == name]
         self.assertTrue(
@@ -188,7 +189,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                 sort_keys=True)
             self.assertTrue(
                 sg.to_json() == json_repr,
-                "JSON sec group representation {0}\n does not match expected {1}"
+                "JSON sec group representation {0} does not match expected {1}"
                 .format(sg.to_json(), json_repr))
 
         sgl = self.provider.security.security_groups.list()
@@ -198,6 +199,22 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             "Security group {0} should have been deleted but still exists."
             .format(name))
 
+    def test_security_group_rule_add_twice(self):
+        """Test whether adding the same rule twice succeeds."""
+        name = 'cbtestsecuritygroupB-{0}'.format(uuid.uuid4())
+        sg = self.provider.security.security_groups.create(
+            name=name, description=name)
+        with helpers.cleanup_action(lambda: sg.delete()):
+            rule = sg.add_rule(ip_protocol='tcp', from_port=1111, to_port=1111,
+                               cidr_ip='0.0.0.0/0')
+            # attempting to add the same rule twice should succeed
+            same_rule = sg.add_rule(ip_protocol='tcp', from_port=1111,
+                                    to_port=1111, cidr_ip='0.0.0.0/0')
+            self.assertTrue(
+                rule == same_rule,
+                "Expected rule {0} not found in security group: {0}".format(
+                    same_rule, sg.rules))
+
     def test_security_group_group_rule(self):
         """Test for proper creation of a security group rule."""
         name = 'cbtestsecuritygroupC-{0}'.format(uuid.uuid4())