Quellcode durchsuchen

Make SecurityGroup add_rule method return a SecurityGroupRule object; add get_rule method to the SecurityGroup class; test

Enis Afgan vor 10 Jahren
Ursprung
Commit
16d3da63ef

+ 34 - 3
cloudbridge/cloud/interfaces/resources.py

@@ -1666,7 +1666,7 @@ class SecurityGroup(CloudResource):
         Create a security group rule.
 
         You need to pass in either ``src_group`` OR ``ip_protocol``,
-        ``from_port``, ``to_port``, and ``cidr_ip``.  In other words, either
+        ``from_port``, ``to_port``, and ``cidr_ip``. In other words, either
         you are authorizing another group or you are authorizing some
         ip-based rule.
 
@@ -1685,8 +1685,39 @@ class SecurityGroup(CloudResource):
         :type src_group: :class:``.SecurityGroup``
         :param src_group: The Security Group object you are granting access to.
 
-        :rtype: bool
-        :return: ``True`` if successful.
+        :rtype: :class:``.SecurityGroupRule``
+        :return: Rule object if successful or ``None``.
+        """
+        pass
+
+    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
+                 cidr_ip=None, src_group=None):
+        """
+        Get a security group rule with the specified parameters.
+
+        You need to pass in either ``src_group`` OR ``ip_protocol``,
+        ``from_port``, ``to_port``, and ``cidr_ip``. Note that when retrieving
+        a group rule, this method will return only one rule although possibly
+        several rules exist for the group rule. In that case, use the
+        ``.rules`` property and filter the results as desired.
+
+        :type ip_protocol: str
+        :param ip_protocol: Either ``tcp`` | ``udp`` | ``icmp``.
+
+        :type from_port: int
+        :param from_port: The beginning port number you are enabling.
+
+        :type to_port: int
+        :param to_port: The ending port number you are enabling.
+
+        :type cidr_ip: str or list of strings
+        :param cidr_ip: The CIDR block you are providing access to.
+
+        :type src_group: :class:``.SecurityGroup``
+        :param src_group: The Security Group object you are granting access to.
+
+        :rtype: :class:``.SecurityGroupRule``
+        :return: Role object if one can be found or ``None``.
         """
         pass
 

+ 28 - 10
cloudbridge/cloud/providers/aws/resources.py

@@ -599,6 +599,9 @@ class AWSSecurityGroup(BaseSecurityGroup):
 
     @property
     def rules(self):
+        # Refresh the local object; otherwise get stale rules
+        self._security_group = self._provider.ec2_conn.get_all_security_groups(
+            group_ids=[self.id])[0]
         return [AWSSecurityGroupRule(self._provider, r, self)
                 for r in self._security_group.rules]
 
@@ -627,16 +630,31 @@ class AWSSecurityGroup(BaseSecurityGroup):
         :type src_group: ``object`` of :class:`.SecurityGroup`
         :param src_group: The Security Group you are granting access to.
 
-        :rtype: bool
-        :return: True if successful.
-        """
-        return self._security_group.authorize(
-            ip_protocol=ip_protocol,
-            from_port=from_port,
-            to_port=to_port,
-            cidr_ip=cidr_ip,
-            # pylint:disable=protected-access
-            src_group=src_group._security_group if src_group else None)
+        :rtype: :class:``.SecurityGroupRule``
+        :return: Rule object if successful or ``None``.
+        """
+        if self._security_group.authorize(
+                ip_protocol=ip_protocol,
+                from_port=from_port,
+                to_port=to_port,
+                cidr_ip=cidr_ip,
+                # pylint:disable=protected-access
+                src_group=src_group._security_group if src_group else None):
+            return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
+                                 src_group)
+        return None
+
+    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
+                 cidr_ip=None, src_group=None):
+        for rule in self._security_group.rules:
+            if (rule.ip_protocol == ip_protocol and
+               str(rule.from_port) == str(from_port) and
+               str(rule.to_port) == str(to_port) and
+               rule.grants[0].cidr_ip == cidr_ip) or \
+               (rule.grants[0].name == src_group.name if src_group and
+               hasattr(rule.grants[0], 'name') else False):
+                return AWSSecurityGroupRule(self._provider, rule, self)
+        return None
 
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))

+ 36 - 17
cloudbridge/cloud/providers/openstack/resources.py

@@ -781,26 +781,45 @@ class OpenStackSecurityGroup(BaseSecurityGroup):
         :type src_group: ``object`` of :class:`.SecurityGroup`
         :param src_group: The Security Group you are granting access to.
 
-        :rtype: bool
-        :return: True if successful.
+        :rtype: :class:``.SecurityGroupRule``
+        :return: Rule object if successful or ``None``.
         """
         if src_group:
-            for protocol in ['tcp', 'udp']:
-                if self._provider.nova.security_group_rules.create(
-                   parent_group_id=self._security_group.id,
-                   ip_protocol=protocol,
-                   from_port=1,
-                   to_port=65535,
-                   group_id=src_group.id):
-                    return True
+            for protocol in ['udp', 'tcp']:
+                rule = self._provider.nova.security_group_rules.create(
+                    parent_group_id=self._security_group.id,
+                    ip_protocol=protocol,
+                    from_port=1,
+                    to_port=65535,
+                    group_id=src_group.id)
+            if rule:
+                # We can only return one Rule so default to TCP (ie, last in
+                # the for loop above).
+                return OpenStackSecurityGroupRule(self._provider,
+                                                  rule.to_dict(), self)
         else:
-            if self._provider.nova.security_group_rules.create(
-               parent_group_id=self._security_group.id,
-               ip_protocol=ip_protocol,
-               from_port=from_port,
-               to_port=to_port,
-               cidr=cidr_ip):
-                return True
+            rule = self._provider.nova.security_group_rules.create(
+                parent_group_id=self._security_group.id,
+                ip_protocol=ip_protocol,
+                from_port=from_port,
+                to_port=to_port,
+                cidr=cidr_ip)
+            if rule:
+                return OpenStackSecurityGroupRule(self._provider,
+                                                  rule.to_dict(), self)
+        return None
+
+    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
+                 cidr_ip=None, src_group=None):
+        for rule in self._security_group.rules:
+            if (rule['ip_protocol'] == ip_protocol and
+               str(rule['from_port']) == str(from_port) and
+               str(rule['to_port']) == str(to_port) and
+               rule['ip_range'].get('cidr') == cidr_ip) or \
+               (rule['group'].get('name') == src_group.name if src_group
+               else False):
+                return OpenStackSecurityGroupRule(self._provider, rule, self)
+        return None
 
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))

+ 14 - 14
test/test_security_service.py

@@ -161,14 +161,15 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
         with helpers.cleanup_action(lambda: sg.delete()):
             sg.add_rule(ip_protocol='tcp', from_port=1111, to_port=1111,
                         cidr_ip='0.0.0.0/0')
-            found_rules = [rule for rule in sg.rules if
-                           rule.cidr_ip == '0.0.0.0/0' and
-                           rule.ip_protocol == 'tcp' and
-                           rule.from_port == 1111 and
-                           rule.to_port == 1111]
+            rule = sg.get_rule(ip_protocol='tcp', from_port=1111, to_port=1111,
+                               cidr_ip='0.0.0.0/0')
             self.assertTrue(
-                len(found_rules) == 1,
-                "Expected rule not found in security group: {0}".format(name))
+                (rule.ip_protocol == 'tcp' and
+                 rule.from_port == 1111 and
+                 rule.to_port == 1111 and
+                 rule.cidr_ip == '0.0.0.0/0'),
+                "Expected rule {0} not found in security group: {0}".format(
+                    rule, sg.rules))
 
             object_keys = (
                 sg.rules[0].ip_protocol,
@@ -187,8 +188,8 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                 "The same security groups should still be equal?")
             json_repr = json.dumps(
                 {"description": name, "name": name, "id": sg.id, "rules":
-                 [{"from_port": 1111, "group": "", "cidr_ip": "0.0.0.0/0",
-                   "parent": sg.id, "to_port": 1111, "ip_protocol": "tcp",
+                 [{"from_port": "1111", "group": "", "cidr_ip": "0.0.0.0/0",
+                   "parent": sg.id, "to_port": "1111", "ip_protocol": "tcp",
                    "id": sg.rules[0].id}]},
                 sort_keys=True)
             self.assertTrue(
@@ -213,13 +214,12 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                 len(sg.rules) == 0,
                 "Expected no security group group rule. Got {0}."
                 .format(sg.rules))
-            sg.add_rule(src_group=sg)
+            rule = sg.add_rule(src_group=sg)
             self.assertTrue(
-                sg.rules[0].group.name == name,
+                rule.group.name == name,
                 "Expected security group rule name {0}. Got {1}."
-                .format(name, sg.rules[0].group.name))
-            sg.rules[0].delete()
-            sg = self.provider.security.security_groups.get(sg.id)  # update
+                .format(name, rule.group.name))
+            rule.delete()
             self.assertTrue(
                 len(sg.rules) == 0,
                 "Deleting SecurityGroupRule should delete it: {0}".format(