Parcourir la source

Made security group rule methods uniform

Nuwan Goonasekera il y a 8 ans
Parent
commit
175ee982db

+ 81 - 19
cloudbridge/cloud/base/resources.py

@@ -627,7 +627,6 @@ class BaseVMFirewall(BaseCloudResource, VMFirewall):
         return (isinstance(other, VMFirewall) and
                 # pylint:disable=protected-access
                 self._provider == other._provider and
-                len(self.rules) == len(other.rules) and  # Shortcut
                 set(self.rules) == set(other.rules))
 
     def __ne__(self, other):
@@ -668,29 +667,85 @@ class BaseVMFirewall(BaseCloudResource, VMFirewall):
                                             self.id, self.name)
 
 
+class BaseVMFirewallRuleContainer(BasePageableObjectMixin):
+
+    def __init__(self, provider, firewall):
+        self.__provider = provider
+        self.firewall = firewall
+
+    @property
+    def _provider(self):
+        return self.__provider
+
+    def get(self, rule_id):
+        matches = [rule for rule in self if rule.id == rule_id]
+        if matches:
+            return matches[0]
+        else:
+            return None
+
+    def find(self, **kwargs):
+        matches = self
+
+        def filter_by(prop_name, rules):
+            prop_val = kwargs.pop(prop_name, None)
+            if prop_val:
+                match = [r for r in rules if getattr(r, prop_name) == prop_val]
+                return match
+            return rules
+
+        matches = filter_by('name', matches)
+        matches = filter_by('direction', matches)
+        matches = filter_by('protocol', matches)
+        matches = filter_by('from_port', matches)
+        matches = filter_by('to_port', matches)
+        matches = filter_by('cidr', matches)
+        matches = filter_by('src_dest_fw', matches)
+        matches = filter_by('src_dest_fw_id', matches)
+        limit = kwargs.pop('limit', None)
+        marker = kwargs.pop('marker', None)
+
+        return ClientPagedResultList(self._provider, matches,
+                                     limit=limit, marker=marker)
+
+    def delete(self, rule_id):
+        rule = self.get(rule_id)
+        if rule:
+            rule.delete()
+
+
 class BaseVMFirewallRule(BaseCloudResource, VMFirewallRule):
 
-    def __init__(self, provider, rule, parent):
-        super(BaseVMFirewallRule, self).__init__(provider)
+    def __init__(self, parent_fw, rule):
+        # pylint:disable=protected-access
+        super(BaseVMFirewallRule, self).__init__(
+            parent_fw._provider)
+        self.firewall = parent_fw
         self._rule = rule
-        self.parent = parent
 
+        # Cache name
+        self._name = "{0}-{1}-{2}-{3}-{4}-{5}".format(
+            self.direction, self.protocol, self.from_port, self.to_port,
+            self.cidr, self.src_dest_fw_id).lower()
+
+    @property
     def name(self):
-        """
-        VM firewall rules don't support names, so pass
-        """
-        pass
+        return self._name
 
     def __repr__(self):
-        return ("<CBVMFirewallRule: IP: {0}; from: {1}; to: {2}; grp: {3}>"
-                .format(self.ip_protocol, self.from_port, self.to_port,
-                        self.group))
+        return ("<CBVMFirewallRule: id: {0}; direction: {1}; protocol: {2};"
+                " from: {3}; to: {4}; cidr: {5}, src_dest_fw: {6}>"
+                .format(self.id, self.direction, self.protocol, self.from_port,
+                        self.to_port, self.cidr, self.src_dest_fw_id))
 
     def __eq__(self, other):
-        return self.ip_protocol == other.ip_protocol and \
-            self.from_port == other.from_port and \
-            self.to_port == other.to_port and \
-            self.cidr_ip == other.cidr_ip
+        return (isinstance(other, VMFirewallRule) and
+                self.direction == other.direction and
+                self.protocol == other.protocol and
+                self.from_port == other.from_port and
+                self.to_port == other.to_port and
+                self.cidr == other.cidr and
+                self.src_dest_fw_id == other.src_dest_fw_id)
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -699,12 +754,19 @@ class BaseVMFirewallRule(BaseCloudResource, VMFirewallRule):
         """
         Return a hash-based interpretation of all of the object's field values.
 
-        This is requried for operations on hashed collections including
+        This is requeried for operations on hashed collections including
         ``set``, ``frozenset``, and ``dict``.
         """
-        return hash("{0}{1}{2}{3}{4}".format(self.ip_protocol, self.from_port,
-                                             self.to_port, self.cidr_ip,
-                                             self.group))
+        return hash("{0}{1}{2}{3}{4}{5}".format(
+            self.direction, self.protocol, self.from_port, self.to_port,
+            self.cidr, self.src_dest_fw_id))
+
+    def to_json(self):
+        attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
+        js = {k: v for (k, v) in attr if not k.startswith('_')}
+        js['src_dest_fw'] = self.src_dest_fw_id
+        js['firewall'] = self.firewall.id
+        return js
 
 
 class BasePlacementZone(BaseCloudResource, PlacementZone):

+ 15 - 2
cloudbridge/cloud/interfaces/exceptions.py

@@ -31,7 +31,7 @@ class InvalidConfigurationException(CloudBridgeBaseException):
 class ProviderConnectionException(CloudBridgeBaseException):
     """
     Marker interface for connection errors to a cloud provider.
-    Thrown when cloudbridge is unable to connect with a provider,
+    Thrown when CloudBridge is unable to connect with a provider,
     for example, when credentials are incorrect, or connection
     settings are invalid.
     """
@@ -41,8 +41,21 @@ class ProviderConnectionException(CloudBridgeBaseException):
 class InvalidNameException(CloudBridgeBaseException):
     """
     Marker interface for any attempt to set an invalid name on
-    a cloudbridge resource.An example would be setting uppercase
+    a CloudBridge resource.An example would be setting uppercase
     letters, which are not allowed in a resource name.
     """
     def __init__(self, msg):
         super(InvalidNameException, self).__init__(msg)
+
+
+class InvalidValueException(CloudBridgeBaseException):
+    """
+    Marker interface for any attempt to set an invalid value on a CloudBridge
+    resource.An example would be setting an unrecognised value for the
+    direction of a firewall rule other than TrafficDirection.INBOUND or
+    TrafficDirection.OUTBOUND.
+    """
+    def __init__(self, param, value):
+        super(InvalidNameException, self).__init__(
+            "Param %s has been given an unrecognised value %s" %
+            (param, value))

+ 120 - 34
cloudbridge/cloud/interfaces/resources.py

@@ -2,6 +2,7 @@
 Specifications for data objects exposed through a provider or service
 """
 from abc import ABCMeta, abstractmethod, abstractproperty
+from enum import Enum
 
 
 class CloudServiceType(object):
@@ -1706,30 +1707,71 @@ class VMFirewall(CloudResource):
         """
         pass
 
+
+class VMFirewallRuleContainer(PageableObjectMixin, CloudResource):
+    """
+    Base interface for Firewall rules.
+    """
+    __metaclass__ = ABCMeta
+
     @abstractmethod
-    def delete(self):
+    def get(self, rule_id):
         """
-        Delete this VM firewall.
+        Returns a firewall rule given its ID. Returns ``None`` if the
+        rule does not exist.
 
-        :rtype: ``bool``
-        :return: ``True`` if successful.
+        Example:
+
+        .. code-block:: python
+
+            fw = provider.security.vm_firewalls.get('my_fw_id')
+            rule = fw.rules.get('rule_id')
+            print(rule.id, rule.name)
+
+        :rtype: :class:`.FirewallRule`
+        :return:  a FirewallRule instance
         """
         pass
 
     @abstractmethod
-    def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
+    def list(self, limit=None, marker=None):
+        """
+        List all firewall rules associated with this firewall.
+
+        :rtype: ``list`` of :class:`.FirewallRule`
+        :return:  list of Firewall rule objects
+        """
+        pass
+
+    @abstractmethod
+    def create(self,  direction, protocol=None, from_port=None,
+               to_port=None, cidr=None, src_dest_fw=None):
         """
         Create a VM firewall rule. If the rule already exists, simply
         returns it.
 
-        You need to pass in either ``src_firewall`` OR ``ip_protocol`` AND
+        Example:
+
+        .. code-block:: python
+            import TafficDirection from cloudbridge.cloud.interfaces.resources
+
+            fw = provider.security.vm_firewalls.get('my_fw_id')
+            fw.rules.create(TrafficDirection.INBOUND, protocol='tcp',
+                            from_port=80, to_port=80, cidr='10.0.0.0/16')
+            fw.rules.create(TrafficDirection.INBOUND, src_dest_fw=fw)
+            fw.rules.create(TrafficDirection.OUTBOUND, src_dest_fw=fw)
+
+        You need to pass in either ``src_dest_fw`` OR ``protocol`` AND
         ``from_port``, ``to_port``, ``cidr_ip``. In other words, either
         you are authorizing another group or you are authorizing some
-        ip-based rule.
+        IP-based rule.
+
+        :type direction: :class:``.TrafficDirection``
+        :param direction: Either ``TrafficDirection.INBOUND`` |
+                          ``TrafficDirection.OUTBOUND``
 
-        :type ip_protocol: ``str``
-        :param ip_protocol: Either ``tcp`` | ``udp`` | ``icmp``.
+        :type protocol: ``str``
+        :param protocol: Either ``tcp`` | ``udp`` | ``icmp``.
 
         :type from_port: ``int``
         :param from_port: The beginning port number you are enabling.
@@ -1737,30 +1779,30 @@ class VMFirewall(CloudResource):
         :type to_port: ``int``
         :param to_port: The ending port number you are enabling.
 
-        :type cidr_ip: ``str`` or list of ``str``
-        :param cidr_ip: The CIDR block you are providing access to.
+        :type cidr: ``str`` or list of ``str``
+        :param cidr: The CIDR block you are providing access to.
 
-        :type src_firewall: :class:`.VMFirewall`
-        :param src_firewall: The VM firewall object you are granting access to.
+        :type src_dest_fw: :class:`.VMFirewall`
+        :param src_dest_fw: The VM firewall object which is the
+                            source/destination of the traffic, depending on
+                            whether it's ingress/egress traffic.
 
         :rtype: :class:`.VMFirewallRule`
         :return: Rule object if successful or ``None``.
         """
         pass
 
-    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
+    @abstractmethod
+    def find(self, **kwargs):
         """
-        Get a VM firewall rule with the specified parameters.
+        Find a firewall rule associated with your account filtered by the given
+        parameters.
 
-        You need to pass in either ``src_firewall`` OR ``ip_protocol`` AND
-        ``from_port``, ``to_port``, and ``cidr_ip``. Note that when retrieving
-        a group rule, this method will return only one rule although possibly
-        several rules exist for the group rule. In that case, use the
-        ``.rules`` property and filter the results as desired.
+        :type name: str
+        :param name: The name of the VM firewall to retrieve.
 
-        :type ip_protocol: ``str``
-        :param ip_protocol: Either ``tcp`` | ``udp`` | ``icmp``.
+        :type protocol: ``str``
+        :param protocol: Either ``tcp`` | ``udp`` | ``icmp``.
 
         :type from_port: ``int``
         :param from_port: The beginning port number you are enabling.
@@ -1768,17 +1810,40 @@ class VMFirewall(CloudResource):
         :type to_port: ``int``
         :param to_port: The ending port number you are enabling.
 
-        :type cidr_ip: ``str`` or list of ``str``
-        :param cidr_ip: The CIDR block you are providing access to.
+        :type cidr: ``str`` or list of ``str``
+        :param cidr: The CIDR block you are providing access to.
 
-        :type src_firewall: :class:`.VMFirewall`
-        :param src_firewall: The VM firewall object you are granting access to.
+        :type src_dest_fw: :class:`.VMFirewall`
+        :param src_dest_fw: The VM firewall object which is the
+                            source/destination of the traffic, depending on
+                            whether it's ingress/egress traffic.
 
-        :rtype: :class:`.VMFirewallRule`
-        :return: Rule object if one can be found or ``None``.
+        :type src_dest_fw_id: :class:`.str`
+        :param src_dest_fw_id: The VM firewall id which is the
+                               source/destination of the traffic, depending on
+                               whether it's ingress/egress traffic.
+
+        :rtype: list of :class:`VMFirewallRule`
+        :return: A list of VMFirewall objects or an empty list if none
+                 found.
         """
         pass
 
+    @abstractmethod
+    def delete(self, rule_id):
+        """
+        Delete an existing VMFirewall rule.
+
+        :type rule_id: str
+        :param rule_id: The VM firewall rule to be deleted.
+        """
+        pass
+
+
+class TrafficDirection(Enum):
+    INBOUND = 'inbound'
+    OUTBOUND = 'outbound'
+
 
 class VMFirewallRule(CloudResource):
 
@@ -1788,7 +1853,18 @@ class VMFirewallRule(CloudResource):
     __metaclass__ = ABCMeta
 
     @abstractproperty
-    def ip_protocol(self):
+    def direction(self):
+        """
+        Direction of traffic to which this rule applies.
+        Either TrafficDirection.INBOUND | TrafficDirection.OUTBOUND
+
+        :rtype: ``str``
+        :return: Direction of traffic to which this rule applies
+        """
+        pass
+
+    @abstractproperty
+    def protocol(self):
         """
         IP protocol used. Either ``tcp`` | ``udp`` | ``icmp``.
 
@@ -1818,7 +1894,7 @@ class VMFirewallRule(CloudResource):
         pass
 
     @abstractproperty
-    def cidr_ip(self):
+    def cidr(self):
         """
         CIDR block this VM firewall is providing access to.
 
@@ -1828,12 +1904,22 @@ class VMFirewallRule(CloudResource):
         pass
 
     @abstractproperty
-    def group(self):
+    def src_dest_fw_id(self):
+        """
+        VM firewall id given access permissions by this rule.
+
+        :rtype: ``str``
+        :return: The VM firewall granted access.
+        """
+        pass
+
+    @abstractproperty
+    def src_dest_fw(self):
         """
         VM firewall given access permissions by this rule.
 
         :rtype: :class:``.VMFirewall``
-        :return: The VM firewall with granting access.
+        :return: The VM firewall granted access.
         """
         pass
 

+ 0 - 6
cloudbridge/cloud/interfaces/services.py

@@ -1181,12 +1181,6 @@ class VMFirewallService(PageableObjectMixin, CloudService):
 
         :type group_id: str
         :param group_id: The VM firewall ID to be deleted.
-
-        :rtype: ``bool``
-        :return:  ``True`` if the VM firewall does not exist, ``False``
-                  otherwise. Note that this implies that the group may not have
-                  been deleted by this method but instead has not existed in
-                  the first place.
         """
         pass
 

+ 102 - 89
cloudbridge/cloud/providers/aws/resources.py

@@ -23,9 +23,11 @@ from cloudbridge.cloud.base.resources import BaseSnapshot
 from cloudbridge.cloud.base.resources import BaseSubnet
 from cloudbridge.cloud.base.resources import BaseVMFirewall
 from cloudbridge.cloud.base.resources import BaseVMFirewallRule
+from cloudbridge.cloud.base.resources import BaseVMFirewallRuleContainer
 from cloudbridge.cloud.base.resources import BaseVMType
 from cloudbridge.cloud.base.resources import BaseVolume
 from cloudbridge.cloud.base.resources import ClientPagedResultList
+from cloudbridge.cloud.interfaces.exceptions import InvalidValueException
 from cloudbridge.cloud.interfaces.resources import GatewayState
 from cloudbridge.cloud.interfaces.resources import InstanceState
 from cloudbridge.cloud.interfaces.resources import MachineImageState
@@ -33,7 +35,7 @@ from cloudbridge.cloud.interfaces.resources import NetworkState
 from cloudbridge.cloud.interfaces.resources import RouterState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import SubnetState
-from cloudbridge.cloud.interfaces.resources import VMFirewall
+from cloudbridge.cloud.interfaces.resources import TrafficDirection
 from cloudbridge.cloud.interfaces.resources import VolumeState
 
 from .helpers import find_tag_value
@@ -548,6 +550,7 @@ class AWSVMFirewall(BaseVMFirewall):
 
     def __init__(self, provider, _vm_firewall):
         super(AWSVMFirewall, self).__init__(provider, _vm_firewall)
+        self._rule_container = AWSVMFirewallRuleContainer(provider, self)
 
     @property
     def name(self):
@@ -559,60 +562,10 @@ class AWSVMFirewall(BaseVMFirewall):
 
     @property
     def rules(self):
-        return [AWSVMFirewallRule(self._provider, r, self)
-                for r in self._vm_firewall.ip_permissions]
+        return self._rule_container
 
-    def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
-        try:
-            src_firewall_id = (
-                src_firewall.id if isinstance(src_firewall, VMFirewall)
-                else src_firewall)
-
-            ip_perm_entry = {
-                'IpProtocol': ip_protocol,
-                'FromPort': from_port,
-                'ToPort': to_port,
-                'IpRanges': [{'CidrIp': cidr_ip}] if cidr_ip else None,
-                'UserIdGroupPairs': [{
-                    'GroupId': src_firewall_id}
-                ] if src_firewall_id else None
-            }
-            # Filter out empty values to please Boto
-            ip_perms = [trim_empty_params(ip_perm_entry)]
-            self._vm_firewall.authorize_ingress(IpPermissions=ip_perms)
-            self._vm_firewall.reload()
-            return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
-                                 src_firewall_id)
-        except ClientError as ec2e:
-            if ec2e.response['Error']['Code'] == "InvalidPermission.Duplicate":
-                return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
-                                     src_firewall)
-            else:
-                raise ec2e
-
-    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
-        src_firewall_id = (
-            src_firewall.id if isinstance(src_firewall, VMFirewall)
-            else src_firewall)
-        for rule in self._vm_firewall.ip_permissions:
-            if ip_protocol and rule['IpProtocol'] != ip_protocol:
-                continue
-            elif from_port and rule['FromPort'] != from_port:
-                continue
-            elif to_port and rule['ToPort'] != to_port:
-                continue
-            elif cidr_ip:
-                if cidr_ip not in [x['CidrIp'] for x in rule['IpRanges']]:
-                    continue
-            elif src_firewall_id:
-                if src_firewall_id not in [
-                    group_pair.get('GroupId') for group_pair in
-                        rule.get('UserIdGroupPairs', [])]:
-                    continue
-            return AWSVMFirewallRule(self._provider, rule, self)
-        return None
+    def refresh(self):
+        self._vm_firewall.reload()
 
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not inspect.isroutine(a))
@@ -624,77 +577,137 @@ class AWSVMFirewall(BaseVMFirewall):
         return js
 
 
+class AWSVMFirewallRuleContainer(BaseVMFirewallRuleContainer):
+
+    def __init__(self, provider, firewall):
+        super(AWSVMFirewallRuleContainer, self).__init__(provider, firewall)
+
+    def list(self, limit=None, marker=None):
+        # pylint:disable=protected-access
+        rules = [AWSVMFirewallRule(self.firewall,
+                                   TrafficDirection.INBOUND, r)
+                 for r in self.firewall._vm_firewall.ip_permissions]
+        rules = rules + [
+            AWSVMFirewallRule(
+                self.firewall, TrafficDirection.OUTBOUND, r)
+            for r in self.firewall._vm_firewall.ip_permissions_egress]
+        return ClientPagedResultList(self._provider, rules,
+                                     limit=limit, marker=marker)
+
+    def create(self,  direction, protocol=None, from_port=None,
+               to_port=None, cidr=None, src_dest_fw=None):
+        src_dest_fw_id = (
+            src_dest_fw.id if isinstance(src_dest_fw, AWSVMFirewall)
+            else src_dest_fw)
+
+        ip_perm_entry = AWSVMFirewallRule._construct_ip_perms(
+            protocol, from_port, to_port, cidr, src_dest_fw_id)
+        # Filter out empty values to please Boto
+        ip_perms = [trim_empty_params(ip_perm_entry)]
+
+        try:
+            if direction == TrafficDirection.INBOUND:
+                # pylint:disable=protected-access
+                self.firewall._vm_firewall.authorize_ingress(
+                    IpPermissions=ip_perms)
+            elif direction == TrafficDirection.OUTBOUND:
+                # pylint:disable=protected-access
+                self.firewall._vm_firewall.authorize_egress(
+                    IpPermissions=ip_perms)
+            else:
+                raise InvalidValueException("direction", direction)
+            self.firewall.refresh()
+            return AWSVMFirewallRule(self.firewall, direction, ip_perm_entry)
+        except ClientError as ec2e:
+            if ec2e.response['Error']['Code'] == "InvalidPermission.Duplicate":
+                return AWSVMFirewallRule(
+                    self.firewall, direction, ip_perm_entry)
+            else:
+                raise ec2e
+
+
 class AWSVMFirewallRule(BaseVMFirewallRule):
 
-    def __init__(self, provider, rule, parent):
-        super(AWSVMFirewallRule, self).__init__(provider, rule, parent)
+    def __init__(self, parent_fw, direction, rule):
+        self._direction = direction
+        super(AWSVMFirewallRule, self).__init__(parent_fw, rule)
+
+        # cache id
+        md5 = hashlib.md5()
+        md5.update(self._name.encode('ascii'))
+        self._id = md5.hexdigest()
 
     @property
     def id(self):
-        md5 = hashlib.md5()
-        md5.update("{0}-{1}-{2}-{3}".format(
-            self.ip_protocol, self.from_port, self.to_port, self.cidr_ip)
-            .encode('ascii'))
-        return md5.hexdigest()
+        return self._id
+
+    @property
+    def direction(self):
+        return self._direction
 
     @property
-    def ip_protocol(self):
+    def protocol(self):
         return self._rule.get('IpProtocol')
 
     @property
     def from_port(self):
-        return self._rule.get('FromPort', 0)
+        return self._rule.get('FromPort')
 
     @property
     def to_port(self):
-        return self._rule.get('ToPort', 0)
+        return self._rule.get('ToPort')
 
     @property
-    def cidr_ip(self):
-        if len(self._rule.get('IpRanges', [])) > 0:
+    def cidr(self):
+        if len(self._rule.get('IpRanges') or []) > 0:
             return self._rule['IpRanges'][0].get('CidrIp')
         return None
 
     @property
-    def group_id(self):
-        if len(self._rule['UserIdGroupPairs']) > 0:
+    def src_dest_fw_id(self):
+        if len(self._rule.get('UserIdGroupPairs') or []) > 0:
             return self._rule['UserIdGroupPairs'][0]['GroupId']
         else:
             return None
 
     @property
-    def group(self):
-        if self.group_id:
+    def src_dest_fw(self):
+        if self.src_dest_fw_id:
             return AWSVMFirewall(
                 self._provider,
-                self._provider.ec2_conn.SecurityGroup(self.group_id))
+                self._provider.ec2_conn.SecurityGroup(self.src_dest_fw_id))
         else:
             return None
 
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
-        js = {k: v for (k, v) in attr if not k.startswith('_')}
-        js['group'] = self.group.id if self.group else ''
-        js['parent'] = self.parent.id if self.parent else ''
-        return js
-
-    def delete(self):
-
-        ip_perm_entry = {
-            'IpProtocol': self.ip_protocol,
-            'FromPort': self.from_port,
-            'ToPort': self.to_port,
-            'IpRanges': [{'CidrIp': self.cidr_ip}] if self.cidr_ip else None,
+    @staticmethod
+    def _construct_ip_perms(protocol, from_port, to_port, cidr,
+                            src_dest_fw_id):
+        return {
+            'IpProtocol': protocol,
+            'FromPort': from_port,
+            'ToPort': to_port,
+            'IpRanges': [{'CidrIp': cidr}] if cidr else None,
             'UserIdGroupPairs': [{
-                'GroupId': self.group_id}
-            ] if self.group_id else None
+                'GroupId': src_dest_fw_id}
+            ] if src_dest_fw_id else None
         }
 
+    def delete(self):
+        ip_perm_entry = self._construct_ip_perms(
+            self.protocol, self.from_port, self.to_port,
+            self.cidr, self.src_dest_fw_id)
+
         # Filter out empty values to please Boto
         ip_perms = [trim_empty_params(ip_perm_entry)]
 
-        self.parent._vm_firewall.revoke_ingress(IpPermissions=ip_perms)
-        self.parent._vm_firewall.reload()
+        # pylint:disable=protected-access
+        if self.direction == TrafficDirection.INBOUND:
+            self.firewall._vm_firewall.revoke_ingress(
+                IpPermissions=ip_perms)
+        else:
+            self.firewall._vm_firewall.revoke_egress(
+                IpPermissions=ip_perms)
+        self.firewall.refresh()
 
 
 class AWSBucketObject(BaseBucketObject):

+ 73 - 104
cloudbridge/cloud/providers/openstack/resources.py

@@ -22,9 +22,11 @@ from cloudbridge.cloud.base.resources import BaseSnapshot
 from cloudbridge.cloud.base.resources import BaseSubnet
 from cloudbridge.cloud.base.resources import BaseVMFirewall
 from cloudbridge.cloud.base.resources import BaseVMFirewallRule
+from cloudbridge.cloud.base.resources import BaseVMFirewallRuleContainer
 from cloudbridge.cloud.base.resources import BaseVMType
 from cloudbridge.cloud.base.resources import BaseVolume
 from cloudbridge.cloud.base.resources import ClientPagedResultList
+from cloudbridge.cloud.interfaces.exceptions import InvalidValueException
 from cloudbridge.cloud.interfaces.resources import GatewayState
 from cloudbridge.cloud.interfaces.resources import InstanceState
 from cloudbridge.cloud.interfaces.resources import MachineImageState
@@ -32,7 +34,7 @@ from cloudbridge.cloud.interfaces.resources import NetworkState
 from cloudbridge.cloud.interfaces.resources import RouterState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import SubnetState
-from cloudbridge.cloud.interfaces.resources import VMFirewall
+from cloudbridge.cloud.interfaces.resources import TrafficDirection
 from cloudbridge.cloud.interfaces.resources import VolumeState
 from cloudbridge.cloud.providers.openstack import helpers as oshelpers
 
@@ -1003,6 +1005,7 @@ class OpenStackVMFirewall(BaseVMFirewall):
 
     def __init__(self, provider, vm_firewall):
         super(OpenStackVMFirewall, self).__init__(provider, vm_firewall)
+        self._rule_svc = OpenStackVMFirewallRuleContainer(provider, self)
 
     @property
     def network_id(self):
@@ -1015,95 +1018,11 @@ class OpenStackVMFirewall(BaseVMFirewall):
 
     @property
     def rules(self):
-        # Update SG object; otherwise, recently added rules do now show
-        self._vm_firewall = self._provider.nova.security_groups.get(
-            self.id)
-        return [OpenStackVMFirewallRule(self._provider, r, self)
-                for r in self._vm_firewall.rules]
-
-    def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
-        """
-        Create a VM firewall rule.
-
-        You need to pass in either ``src_firewall`` OR ``ip_protocol`` AND
-        ``from_port``, ``to_port``, ``cidr_ip``.  In other words, either
-        you are authorizing another group or you are authorizing some
-        ip-based rule.
-
-        :type ip_protocol: str
-        :param ip_protocol: Either ``tcp`` | ``udp`` | ``icmp``
-
-        :type from_port: int
-        :param from_port: The beginning port number you are enabling
-
-        :type to_port: int
-        :param to_port: The ending port number you are enabling
-
-        :type cidr_ip: str or list of strings
-        :param cidr_ip: The CIDR block you are providing access to.
-
-        :type src_firewall: ``object`` of :class:`.VMFirewall`
-        :param src_firewall: The VM firewall you are granting access to.
-
-        :rtype: :class:``.VMFirewallRule``
-        :return: Rule object if successful or ``None``.
-        """
-        if src_firewall:
-            if not isinstance(src_firewall, VMFirewall):
-                src_firewall = self._provider.security.vm_firewalls.get(
-                    src_firewall)
-            existing_rule = self.get_rule(ip_protocol=ip_protocol,
-                                          from_port=from_port,
-                                          to_port=to_port,
-                                          src_firewall=src_firewall)
-            if existing_rule:
-                return existing_rule
-
-            rule = self._provider.nova.security_group_rules.create(
-                parent_group_id=self.id,
-                ip_protocol=ip_protocol,
-                from_port=from_port,
-                to_port=to_port,
-                group_id=src_firewall.id)
-            if rule:
-                # We can only return one Rule so default to TCP (ie, last in
-                # the for loop above).
-                return OpenStackVMFirewallRule(self._provider,
-                                               rule.to_dict(), self)
-        else:
-            existing_rule = self.get_rule(ip_protocol=ip_protocol,
-                                          from_port=from_port,
-                                          to_port=to_port,
-                                          cidr_ip=cidr_ip)
-            if existing_rule:
-                return existing_rule
-
-            rule = self._provider.nova.security_group_rules.create(
-                parent_group_id=self.id,
-                ip_protocol=ip_protocol,
-                from_port=from_port,
-                to_port=to_port,
-                cidr=cidr_ip)
-            if rule:
-                return OpenStackVMFirewallRule(self._provider,
-                                               rule.to_dict(), self)
-        return None
+        return self._rule_svc
 
-    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_firewall=None):
-        # Update SG object; otherwise, recently added rules do not show
-        self._security_group = self._provider.nova.security_groups.get(
+    def refresh(self):
+        self._vm_firewall = self._provider.nova.security_groups.get(
             self.id)
-        for rule in self._vm_firewall.rules:
-            if (rule['ip_protocol'] == ip_protocol and
-                rule['from_port'] == from_port and
-                rule['to_port'] == to_port and
-                (rule['ip_range'].get('cidr') == cidr_ip or
-                 (rule['group'].get('name') == src_firewall.name
-                  if src_firewall else False))):
-                return OpenStackVMFirewallRule(self._provider, rule, self)
-        return None
 
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
@@ -1113,49 +1032,99 @@ class OpenStackVMFirewall(BaseVMFirewall):
         return js
 
 
+class OpenStackVMFirewallRuleContainer(BaseVMFirewallRuleContainer):
+
+    def __init__(self, provider, firewall):
+        super(OpenStackVMFirewallRuleContainer, self).__init__(
+            provider, firewall)
+
+    def list(self, limit=None, marker=None):
+        # pylint:disable=protected-access
+        rules = [OpenStackVMFirewallRule(self.firewall, r)
+                 for r in self.firewall._vm_firewall.rules]
+        return ClientPagedResultList(self._provider, rules,
+                                     limit=limit, marker=marker)
+
+    def create(self,  direction, protocol=None, from_port=None,
+               to_port=None, cidr=None, src_dest_fw=None):
+        src_dest_fw_id = (src_dest_fw.id if isinstance(src_dest_fw,
+                                                       OpenStackVMFirewall)
+                          else src_dest_fw)
+
+        try:
+            if direction == TrafficDirection.INBOUND:
+                # pylint:disable=protected-access
+                rule = self._provider.nova.security_group_rules.create(
+                                parent_group_id=self.firewall.id,
+                                ip_protocol=protocol,
+                                from_port=from_port,
+                                to_port=to_port,
+                                cidr=cidr,
+                                group_id=src_dest_fw_id)
+            elif direction == TrafficDirection.OUTBOUND:
+                pass
+            else:
+                raise InvalidValueException("direction", direction)
+            self.firewall.refresh()
+            return OpenStackVMFirewallRule(self.firewall, rule.to_dict())
+        except novaex.BadRequest as e:
+            self.firewall.refresh()
+            existing = self.find(
+                direction=direction, protocol=protocol, from_port=from_port,
+                to_port=to_port, cidr=cidr, src_dest_fw_id=src_dest_fw_id)
+            if existing:
+                return existing[0]
+            else:
+                raise e
+
+
 class OpenStackVMFirewallRule(BaseVMFirewallRule):
 
-    def __init__(self, provider, rule, parent):
-        super(OpenStackVMFirewallRule, self).__init__(
-            provider, rule, parent)
+    def __init__(self, parent_fw, rule):
+        super(OpenStackVMFirewallRule, self).__init__(parent_fw, rule)
 
     @property
     def id(self):
         return self._rule.get('id')
 
     @property
-    def ip_protocol(self):
+    def direction(self):
+        return TrafficDirection.INBOUND
+
+    @property
+    def protocol(self):
         return self._rule.get('ip_protocol')
 
     @property
     def from_port(self):
-        return int(self._rule.get('from_port') or 0)
+        return self._rule.get('from_port')
 
     @property
     def to_port(self):
-        return int(self._rule.get('to_port') or 0)
+        return self._rule.get('to_port')
 
     @property
-    def cidr_ip(self):
+    def cidr(self):
         return self._rule.get('ip_range', {}).get('cidr')
 
     @property
-    def group(self):
+    def src_dest_fw_id(self):
+        fw = self.src_dest_fw
+        if fw:
+            return fw.id
+        return None
+
+    @property
+    def src_dest_fw(self):
         fw_name = self._rule.get('group', {}).get('name')
         if fw_name:
             fw = self._provider.security.vm_firewalls.find(name=fw_name)
             return fw[0] if fw else None
         return None
 
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
-        js = {k: v for(k, v) in attr if not k.startswith('_')}
-        js['group'] = self.group.id if self.group else ''
-        js['parent'] = self.parent.id if self.parent else ''
-        return js
-
     def delete(self):
-        return self._provider.nova.security_group_rules.delete(self.id)
+        self._provider.nova.security_group_rules.delete(self.id)
+        self.firewall.refresh()
 
 
 class OpenStackBucketObject(BaseBucketObject):

+ 0 - 41
cloudbridge/cloud/providers/openstack/services.py

@@ -181,9 +181,6 @@ class OpenStackVMFirewallService(BaseVMFirewallService):
         super(OpenStackVMFirewallService, self).__init__(provider)
 
     def get(self, firewall_id):
-        """
-        Returns a VMFirewall given its id.
-        """
         try:
             return OpenStackVMFirewall(
                 self.provider,
@@ -192,13 +189,6 @@ class OpenStackVMFirewallService(BaseVMFirewallService):
             return None
 
     def list(self, limit=None, marker=None):
-        """
-        List all VM firewalls associated with this account.
-
-        :rtype: ``list`` of :class:`.VMFirewall`
-        :return:  list of VMFirewall objects
-        """
-
         firewalls = [OpenStackVMFirewall(self.provider, fw)
                      for fw in self.provider.nova.security_groups.list()]
 
@@ -206,22 +196,6 @@ class OpenStackVMFirewallService(BaseVMFirewallService):
                                      limit=limit, marker=marker)
 
     def create(self, name, description, network_id):
-        """
-        Create a new VM firewall under the current account.
-
-        :type name: str
-        :param name: The name of the new VM firewall.
-
-        :type description: str
-        :param description: The description of the new VM firewall.
-
-        :type  network_id: ``None``
-        :param network_id: Not applicable for OpenStack (yet) so any value is
-                           ignored.
-
-        :rtype: ``object`` of :class:`.VMFirewall`
-        :return: a VMFirewall object
-        """
         OpenStackVMFirewall.assert_valid_resource_name(name)
 
         sg = self.provider.nova.security_groups.create(name, description)
@@ -230,9 +204,6 @@ class OpenStackVMFirewallService(BaseVMFirewallService):
         return None
 
     def find(self, name, limit=None, marker=None):
-        """
-        Get all VM firewalls associated with your account.
-        """
         sgs = self.provider.nova.security_groups.findall(name=name)
         results = [OpenStackVMFirewall(self.provider, sg)
                    for sg in sgs]
@@ -240,18 +211,6 @@ class OpenStackVMFirewallService(BaseVMFirewallService):
                                      limit=limit, marker=marker)
 
     def delete(self, group_id):
-        """
-        Delete an existing VMFirewall.
-
-        :type group_id: str
-        :param group_id: The VM firewall ID to be deleted.
-
-        :rtype: ``bool``
-        :return:  ``True`` if the VM firewall does not exist, ``False``
-                  otherwise. Note that this implies that the group may not have
-                  been deleted by this method but instead has not existed in
-                  the first place.
-        """
         firewall = self.get(group_id)
         if firewall:
             firewall.delete()

+ 47 - 47
test/test_security_service.py

@@ -4,7 +4,9 @@ from test.helpers import ProviderTestBase
 from test.helpers import standard_interface_tests as sit
 
 from cloudbridge.cloud.interfaces.resources import KeyPair
+from cloudbridge.cloud.interfaces.resources import TrafficDirection
 from cloudbridge.cloud.interfaces.resources import VMFirewall
+from cloudbridge.cloud.interfaces.resources import VMFirewallRule
 
 
 class CloudSecurityServiceTestCase(ProviderTestBase):
@@ -78,38 +80,33 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
             self.assertEqual(name, fw.description)
 
-            rule = fw.add_rule(ip_protocol='tcp', from_port=1111, to_port=1111,
-                               cidr_ip='0.0.0.0/0')
-            found_rule = fw.get_rule(ip_protocol='tcp', from_port=1111,
-                                     to_port=1111, cidr_ip='0.0.0.0/0')
-            self.assertTrue(
-                rule == found_rule,
-                "Expected rule {0} not found in VM firewall: {0}".format(
-                    rule, fw.rules))
-
-            object_keys = (
-                fw.rules[0].ip_protocol,
-                fw.rules[0].from_port,
-                fw.rules[0].to_port)
-            self.assertTrue(
-                all(str(key) in repr(fw.rules[0]) for key in object_keys),
-                "repr(obj) should contain ip_protocol, form_port, and to_port"
-                " so that the object can be reconstructed, but does not:"
-                " {0}; {1}".format(fw.rules[0], object_keys))
-            self.assertTrue(
-                fw == fw,
-                "The same VM firewalls should be equal?")
-            self.assertFalse(
-                fw != fw,
-                "The same VM firewalls should still be equal?")
+    @helpers.skipIfNoService(['security.vm_firewalls'])
+    def test_crud_vm_firewall_rules(self):
+        name = 'cb_crudfw_rules-{0}'.format(helpers.get_uuid())
 
-        sit.check_delete(self, self.provider.security.vm_firewalls, fw)
-        fwl = self.provider.security.vm_firewalls.list()
-        found_fw = [f for f in fwl if f.name == name]
-        self.assertTrue(
-            len(found_fw) == 0,
-            "VM firewall {0} should have been deleted but still exists."
-            .format(name))
+        # Declare these variables and late binding will allow
+        # the cleanup method access to the most current values
+        net = None
+        with helpers.cleanup_action(lambda: helpers.cleanup_test_resources(
+                network=net)):
+            net, _ = helpers.create_test_network(self.provider, name)
+
+            fw = None
+            with helpers.cleanup_action(lambda: fw.delete()):
+                fw = self.provider.security.vm_firewalls.create(
+                    name=name, description=name, network_id=net.id)
+
+                def create_fw_rule(name):
+                    return fw.rules.create(
+                        direction=TrafficDirection.INBOUND, protocol='tcp',
+                        from_port=1111, to_port=1111, cidr='0.0.0.0/0')
+
+                def cleanup_fw_rule(rule):
+                    rule.delete()
+
+                sit.check_crud(self, fw.rules, VMFirewallRule, "cb_crudfwrule",
+                               create_fw_rule, cleanup_fw_rule,
+                               skip_name_check=True)
 
     @helpers.skipIfNoService(['security.vm_firewalls'])
     def test_vm_firewall_rule_add_twice(self):
@@ -126,15 +123,14 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             fw = self.provider.security.vm_firewalls.create(
                 name=name, description=name, network_id=net.id)
 
-            rule = fw.add_rule(ip_protocol='tcp', from_port=1111, to_port=1111,
-                               cidr_ip='0.0.0.0/0')
+            rule = fw.rules.create(
+                direction=TrafficDirection.INBOUND, protocol='tcp',
+                from_port=1111, to_port=1111, cidr='0.0.0.0/0')
             # attempting to add the same rule twice should succeed
-            same_rule = fw.add_rule(ip_protocol='tcp', from_port=1111,
-                                    to_port=1111, cidr_ip='0.0.0.0/0')
-            self.assertTrue(
-                rule == same_rule,
-                "Expected rule {0} not found in VM firewall: {0}".format(
-                    same_rule, fw.rules))
+            same_rule = fw.rules.create(
+                direction=TrafficDirection.INBOUND, protocol='tcp',
+                from_port=1111, to_port=1111, cidr='0.0.0.0/0')
+            self.assertEqual(rule, same_rule)
 
     @helpers.skipIfNoService(['security.vm_firewalls'])
     def test_vm_firewall_group_rule(self):
@@ -149,21 +145,25 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             net, _ = helpers.create_test_network(self.provider, name)
             fw = self.provider.security.vm_firewalls.create(
                 name=name, description=name, network_id=net.id)
+            rules = list(fw.rules)
             self.assertTrue(
-                len(fw.rules) == 0,
-                "Expected no VM firewall group rule. Got {0}."
-                .format(fw.rules))
-            rule = fw.add_rule(src_firewall=fw, ip_protocol='tcp', from_port=1,
-                               to_port=65535)
+                len(rules) == 1, "Expected a single VM firewall rule allowing"
+                " all outbound traffic. Got {0}.".format(rules))
+            self.assertEqual(
+                rules[0].direction, TrafficDirection.OUTBOUND,
+                "Expected rule to be outbound. Got {0}.".format(rules))
+            rule = fw.rules.create(
+                direction=TrafficDirection.INBOUND, src_dest_fw=fw,
+                protocol='tcp', from_port=1, to_port=65535)
             self.assertTrue(
-                rule.group.name == name,
+                rule.src_dest_fw.name == name,
                 "Expected VM firewall rule name {0}. Got {1}."
-                .format(name, rule.group.name))
+                .format(name, rule.src_dest_fw.name))
             for r in fw.rules:
                 r.delete()
             fw = self.provider.security.vm_firewalls.get(fw.id)  # update
             self.assertTrue(
-                len(fw.rules) == 0,
+                len(list(fw.rules)) == 0,
                 "Deleting VMFirewallRule should delete it: {0}".format(
                     fw.rules))
         fwl = self.provider.security.vm_firewalls.list()