Pārlūkot izejas kodu

Moved security group rules under rule container

Nuwan Goonasekera 8 gadi atpakaļ
vecāks
revīzija
b1d4d225f1

+ 77 - 75
cloudbridge/cloud/providers/azure/resources.py

@@ -1,12 +1,12 @@
 """
 DataTypes used by this provider
 """
+import collections
 import inspect
 import json
 import logging
 import time
 
-
 from azure.common import AzureException
 from azure.mgmt.network.models import NetworkSecurityGroup
 
@@ -14,12 +14,12 @@ from cloudbridge.cloud.base.resources import BaseAttachmentInfo, \
     BaseBucket, BaseBucketContainer, BaseBucketObject, BaseFloatingIP, \
     BaseInstance, BaseInternetGateway, BaseKeyPair, BaseLaunchConfig, \
     BaseMachineImage, BaseNetwork, BasePlacementZone, BaseRegion, BaseRouter, \
-    BaseSnapshot, BaseSubnet, BaseVMFirewall, BaseVMFirewallRule, BaseVMType, \
-    BaseVolume, ClientPagedResultList
+    BaseSnapshot, BaseSubnet, BaseVMFirewall, BaseVMFirewallRule, \
+    BaseVMFirewallRuleContainer, BaseVMType, BaseVolume, ClientPagedResultList
 from cloudbridge.cloud.interfaces import InstanceState, VolumeState
 from cloudbridge.cloud.interfaces.resources import Instance, \
     MachineImageState, NetworkState, RouterState, \
-    SnapshotState, SubnetState
+    SnapshotState, SubnetState, TrafficDirection
 
 from msrestazure.azure_exceptions import CloudError
 
@@ -74,6 +74,7 @@ class AzureVMFirewall(BaseVMFirewall):
         self._vm_firewall = vm_firewall
         if not self._vm_firewall.tags:
             self._vm_firewall.tags = {}
+        self._rule_container = AzureVMFirewallRuleContainer(provider, self)
 
     @property
     def network_id(self):
@@ -112,15 +113,7 @@ class AzureVMFirewall(BaseVMFirewall):
 
     @property
     def rules(self):
-        """
-        The default rules are not returned, only custom rules.
-        """
-        vm_firewall_rules = []
-        for custom_rule in self._vm_firewall.security_rules:
-            fw_custom_rule = AzureVMFirewallRule(self._provider,
-                                                 custom_rule, self)
-            vm_firewall_rules.append(fw_custom_rule)
-        return vm_firewall_rules
+        return self._rule_container
 
     def delete(self):
         try:
@@ -144,130 +137,139 @@ class AzureVMFirewall(BaseVMFirewall):
             log.exception(cloudError.message)
             # The security group no longer exists and cannot be refreshed.
 
-    def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr=None, src_dest_fw=None):
-        if ip_protocol and from_port and to_port:
-            return self._create_rule(ip_protocol, from_port, to_port, cidr)
+    def to_json(self):
+        attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
+        js = {k: v for (k, v) in attr if not k.startswith('_')}
+        json_rules = [r.to_json() for r in self.rules]
+        js['rules'] = [json.loads(r) for r in json_rules]
+        if js.get('network_id'):
+            js.pop('network_id')  # Omit for consistency across cloud providers
+        return js
+
+
+class AzureVMFirewallRuleContainer(BaseVMFirewallRuleContainer):
+
+    def __init__(self, provider, firewall):
+        super(AzureVMFirewallRuleContainer, self).__init__(provider, firewall)
+
+    def list(self, limit=None, marker=None):
+        # pylint:disable=protected-access
+        rules = (
+            [AzureVMFirewallRule(self.firewall, rule) for rule
+             in self.firewall._vm_firewall.security_rules] +
+            [AzureVMFirewallRule(self.firewall, rule) for rule
+             in self.firewall._vm_firewall.default_security_rules
+             if rule.destination_address_prefix == "Internet"])
+        return ClientPagedResultList(self._provider, rules,
+                                     limit=limit, marker=marker)
+
+    def create(self, direction, protocol=None, from_port=None, to_port=None,
+               cidr=None, src_dest_fw=None):
+        if protocol and from_port and to_port:
+            return self._create_rule(direction, protocol, from_port,
+                                     to_port, cidr)
         elif src_dest_fw:
             result = None
             fw = (self._provider.security.vm_firewalls.get(src_dest_fw)
                   if isinstance(src_dest_fw, str) else src_dest_fw)
             for rule in fw.rules:
-                result = self._create_rule(rule.ip_protocol, rule.from_port,
-                                           rule.to_port, rule.cidr)
+                result = self._create_rule(
+                    rule.direction, rule.protocol, rule.from_port,
+                    rule.to_port, rule.cidr)
             return result
         else:
             return None
 
-    def _create_rule(self, ip_protocol, from_port, to_port, cidr):
+    def _create_rule(self, direction, protocol, from_port, to_port, cidr):
 
         # If cidr is None, default values is set as 0.0.0.0/0
         if not cidr:
             cidr = '0.0.0.0/0'
 
-        # If the fw with same parameters exist already,
-        # then, it is returned instead of creating a new one.
-        rule = self.get_rule(ip_protocol, from_port, to_port, cidr)
-
-        if rule:
-            return rule
-
-        count = len(self.rules) + 1
+        count = len(self.firewall._vm_firewall.security_rules) + 1
         rule_name = "Rule - " + str(count)
         priority = count * 100
         destination_port_range = str(from_port) + "-" + str(to_port)
         source_port_range = '*'
         destination_address_prefix = "*"
         access = "Allow"
-        direction = "Inbound"
-        parameters = {"protocol": ip_protocol,
+        direction = ("Inbound" if direction == TrafficDirection.INBOUND
+                     else "Outbound")
+        parameters = {"protocol": protocol,
                       "source_port_range": source_port_range,
                       "destination_port_range": destination_port_range,
                       "priority": priority,
                       "source_address_prefix": cidr,
-                      "destination_address_prefix":
-                          destination_address_prefix,
-                      "access": access, "direction": direction}
+                      "destination_address_prefix": destination_address_prefix,
+                      "access": access,
+                      "direction": direction}
         result = self._provider.azure_client. \
-            create_vm_firewall_rule(self.id,
+            create_vm_firewall_rule(self.firewall.id,
                                     rule_name, parameters)
-        self._vm_firewall.security_rules.append(result)
-        return AzureVMFirewallRule(self._provider, result, self)
-
-    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr=None, src_dest_fw=None):
-        for rule in self.rules:
-            if (rule.ip_protocol == ip_protocol and
-                rule.from_port == str(from_port) and
-                rule.to_port == str(to_port) and
-                    rule.cidr == cidr):
-                return rule
-        return None
+        # pylint:disable=protected-access
+        self.firewall._vm_firewall.security_rules.append(result)
+        return AzureVMFirewallRule(self.firewall, result)
 
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
-        js = {k: v for (k, v) in attr if not k.startswith('_')}
-        json_rules = [r.to_json() for r in self.rules]
-        js['rules'] = [json.loads(r) for r in json_rules]
-        if js.get('network_id'):
-            js.pop('network_id')  # Omit for consistency across cloud providers
-        return js
+
+# Tuple for port range
+PortRange = collections.namedtuple('PortRange', ['from_port', 'to_port'])
 
 
 class AzureVMFirewallRule(BaseVMFirewallRule):
-    def __init__(self, provider, rule, parent):
-        super(AzureVMFirewallRule, self).__init__(provider, rule, parent)
+    def __init__(self, parent_fw, rule):
+        super(AzureVMFirewallRule, self).__init__(parent_fw, rule)
 
     @property
     def id(self):
         return self._rule.name
 
+    @property
+    def direction(self):
+        return (TrafficDirection.INBOUND if self._rule.direction == "Inbound"
+                else TrafficDirection.OUTBOUND)
+
     @property
     def name(self):
         return self._rule.name
 
     @property
-    def ip_protocol(self):
+    def protocol(self):
         return self._rule.protocol
 
     @property
     def from_port(self):
-        if self._rule.destination_port_range == '*':
-            return self._rule.destination_port_range
-        destination_port_range = self._rule.destination_port_range
-        port_range_split = destination_port_range.split('-', 1)
-        return port_range_split[0]
+        return self._port_range_tuple().from_port
 
     @property
     def to_port(self):
+        return self._port_range_tuple().to_port
+
+    def _port_range_tuple(self):
         if self._rule.destination_port_range == '*':
-            return self._rule.destination_port_range
+            return PortRange(1, 65535)
         destination_port_range = self._rule.destination_port_range
         port_range_split = destination_port_range.split('-', 1)
-        return port_range_split[1]
+        return PortRange(int(port_range_split[0]), int(port_range_split[1]))
 
     @property
     def cidr(self):
         return self._rule.source_address_prefix
 
     @property
-    def group(self):
-        return self.parent
+    def src_dest_fw_id(self):
+        return self.firewall.id
 
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
-        js = {k: v for (k, v) in attr if not k.startswith('_')}
-        js['group'] = self.group.id if self.group else ''
-        js['parent'] = self.parent.id if self.parent else ''
-        return json.dumps(js, sort_keys=True)
+    @property
+    def src_dest_fw(self):
+        return self.firewall
 
     def delete(self):
-        vm_firewall = self.parent.name
+        vm_firewall = self.firewall.name
         self._provider.azure_client. \
             delete_vm_firewall_rule(self.id, vm_firewall)
-        for i, o in enumerate(self.parent._vm_firewall.security_rules):
+        for i, o in enumerate(self.firewall._vm_firewall.security_rules):
             if o.name == self.name:
-                del self.parent._vm_firewall.security_rules[i]
+                del self.firewall._vm_firewall.security_rules[i]
                 break
 
 

+ 7 - 7
cloudbridge/cloud/providers/azure/services.py

@@ -83,7 +83,7 @@ class AzureVMFirewallService(BaseVMFirewallService):
         """
         filters = {'Name': name}
         fws = [AzureVMFirewall(self.provider, vm_firewall)
-               for vm_firewall in azure_helpers.filter(
+               for vm_firewall in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_vm_firewall(), filters)]
 
         return ClientPagedResultList(self.provider, fws,
@@ -245,7 +245,7 @@ class AzureVolumeService(BaseVolumeService):
         """
         filters = {'Name': name}
         cb_vols = [AzureVolume(self.provider, volume)
-                   for volume in azure_helpers.filter(
+                   for volume in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_disks(), filters)]
         return ClientPagedResultList(self.provider, cb_vols,
                                      limit=limit, marker=marker)
@@ -324,7 +324,7 @@ class AzureSnapshotService(BaseSnapshotService):
         """
         filters = {'Name': name}
         cb_snapshots = [AzureSnapshot(self.provider, snapshot)
-                        for snapshot in azure_helpers.filter(
+                        for snapshot in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_snapshots(), filters)]
         return ClientPagedResultList(self.provider, cb_snapshots,
                                      limit=limit, marker=marker)
@@ -657,7 +657,7 @@ class AzureInstanceService(BaseInstanceService):
         """
         filtr = {'Name': name}
         instances = [AzureInstance(self.provider, inst)
-                     for inst in azure_helpers.filter(
+                     for inst in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_vm(), filtr)]
         return ClientPagedResultList(self.provider, instances,
                                      limit=limit, marker=marker)
@@ -686,7 +686,7 @@ class AzureImageService(BaseImageService):
         """
         filters = {'Name': name}
         cb_images = [AzureMachineImage(self.provider, image)
-                     for image in azure_helpers.filter(
+                     for image in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_images(), filters)]
         return ClientPagedResultList(self.provider, cb_images,
                                      limit=limit, marker=marker)
@@ -778,7 +778,7 @@ class AzureNetworkService(BaseNetworkService):
     def find(self, name, limit=None, marker=None):
         filters = {'Name': name}
         networks = [AzureNetwork(self.provider, network)
-                    for network in azure_helpers.filter(
+                    for network in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_networks(), filters)]
         return ClientPagedResultList(self.provider, networks,
                                      limit=limit, marker=marker)
@@ -1020,7 +1020,7 @@ class AzureRouterService(BaseRouterService):
     def find(self, name, limit=None, marker=None):
         filters = {'Name': name}
         routes = [AzureRouter(self.provider, route)
-                  for route in azure_helpers.filter(
+                  for route in azure_helpers.filter_by_tag(
                 self.provider.azure_client.list_route_tables(), filters)]
 
         return ClientPagedResultList(self.provider, routes,