|
|
@@ -1,12 +1,12 @@
|
|
|
"""
|
|
|
DataTypes used by this provider
|
|
|
"""
|
|
|
+import collections
|
|
|
import inspect
|
|
|
import json
|
|
|
import logging
|
|
|
import time
|
|
|
|
|
|
-
|
|
|
from azure.common import AzureException
|
|
|
from azure.mgmt.network.models import NetworkSecurityGroup
|
|
|
|
|
|
@@ -14,12 +14,12 @@ from cloudbridge.cloud.base.resources import BaseAttachmentInfo, \
|
|
|
BaseBucket, BaseBucketContainer, BaseBucketObject, BaseFloatingIP, \
|
|
|
BaseInstance, BaseInternetGateway, BaseKeyPair, BaseLaunchConfig, \
|
|
|
BaseMachineImage, BaseNetwork, BasePlacementZone, BaseRegion, BaseRouter, \
|
|
|
- BaseSnapshot, BaseSubnet, BaseVMFirewall, BaseVMFirewallRule, BaseVMType, \
|
|
|
- BaseVolume, ClientPagedResultList
|
|
|
+ BaseSnapshot, BaseSubnet, BaseVMFirewall, BaseVMFirewallRule, \
|
|
|
+ BaseVMFirewallRuleContainer, BaseVMType, BaseVolume, ClientPagedResultList
|
|
|
from cloudbridge.cloud.interfaces import InstanceState, VolumeState
|
|
|
from cloudbridge.cloud.interfaces.resources import Instance, \
|
|
|
MachineImageState, NetworkState, RouterState, \
|
|
|
- SnapshotState, SubnetState
|
|
|
+ SnapshotState, SubnetState, TrafficDirection
|
|
|
|
|
|
from msrestazure.azure_exceptions import CloudError
|
|
|
|
|
|
@@ -74,6 +74,7 @@ class AzureVMFirewall(BaseVMFirewall):
|
|
|
self._vm_firewall = vm_firewall
|
|
|
if not self._vm_firewall.tags:
|
|
|
self._vm_firewall.tags = {}
|
|
|
+ self._rule_container = AzureVMFirewallRuleContainer(provider, self)
|
|
|
|
|
|
@property
|
|
|
def network_id(self):
|
|
|
@@ -112,15 +113,7 @@ class AzureVMFirewall(BaseVMFirewall):
|
|
|
|
|
|
@property
|
|
|
def rules(self):
|
|
|
- """
|
|
|
- The default rules are not returned, only custom rules.
|
|
|
- """
|
|
|
- vm_firewall_rules = []
|
|
|
- for custom_rule in self._vm_firewall.security_rules:
|
|
|
- fw_custom_rule = AzureVMFirewallRule(self._provider,
|
|
|
- custom_rule, self)
|
|
|
- vm_firewall_rules.append(fw_custom_rule)
|
|
|
- return vm_firewall_rules
|
|
|
+ return self._rule_container
|
|
|
|
|
|
def delete(self):
|
|
|
try:
|
|
|
@@ -144,130 +137,139 @@ class AzureVMFirewall(BaseVMFirewall):
|
|
|
log.exception(cloudError.message)
|
|
|
# The security group no longer exists and cannot be refreshed.
|
|
|
|
|
|
- def add_rule(self, ip_protocol=None, from_port=None, to_port=None,
|
|
|
- cidr=None, src_dest_fw=None):
|
|
|
- if ip_protocol and from_port and to_port:
|
|
|
- return self._create_rule(ip_protocol, from_port, to_port, cidr)
|
|
|
+ def to_json(self):
|
|
|
+ attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
|
|
|
+ js = {k: v for (k, v) in attr if not k.startswith('_')}
|
|
|
+ json_rules = [r.to_json() for r in self.rules]
|
|
|
+ js['rules'] = [json.loads(r) for r in json_rules]
|
|
|
+ if js.get('network_id'):
|
|
|
+ js.pop('network_id') # Omit for consistency across cloud providers
|
|
|
+ return js
|
|
|
+
|
|
|
+
|
|
|
+class AzureVMFirewallRuleContainer(BaseVMFirewallRuleContainer):
|
|
|
+
|
|
|
+ def __init__(self, provider, firewall):
|
|
|
+ super(AzureVMFirewallRuleContainer, self).__init__(provider, firewall)
|
|
|
+
|
|
|
+ def list(self, limit=None, marker=None):
|
|
|
+ # pylint:disable=protected-access
|
|
|
+ rules = (
|
|
|
+ [AzureVMFirewallRule(self.firewall, rule) for rule
|
|
|
+ in self.firewall._vm_firewall.security_rules] +
|
|
|
+ [AzureVMFirewallRule(self.firewall, rule) for rule
|
|
|
+ in self.firewall._vm_firewall.default_security_rules
|
|
|
+ if rule.destination_address_prefix == "Internet"])
|
|
|
+ return ClientPagedResultList(self._provider, rules,
|
|
|
+ limit=limit, marker=marker)
|
|
|
+
|
|
|
+ def create(self, direction, protocol=None, from_port=None, to_port=None,
|
|
|
+ cidr=None, src_dest_fw=None):
|
|
|
+ if protocol and from_port and to_port:
|
|
|
+ return self._create_rule(direction, protocol, from_port,
|
|
|
+ to_port, cidr)
|
|
|
elif src_dest_fw:
|
|
|
result = None
|
|
|
fw = (self._provider.security.vm_firewalls.get(src_dest_fw)
|
|
|
if isinstance(src_dest_fw, str) else src_dest_fw)
|
|
|
for rule in fw.rules:
|
|
|
- result = self._create_rule(rule.ip_protocol, rule.from_port,
|
|
|
- rule.to_port, rule.cidr)
|
|
|
+ result = self._create_rule(
|
|
|
+ rule.direction, rule.protocol, rule.from_port,
|
|
|
+ rule.to_port, rule.cidr)
|
|
|
return result
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
- def _create_rule(self, ip_protocol, from_port, to_port, cidr):
|
|
|
+ def _create_rule(self, direction, protocol, from_port, to_port, cidr):
|
|
|
|
|
|
# If cidr is None, default values is set as 0.0.0.0/0
|
|
|
if not cidr:
|
|
|
cidr = '0.0.0.0/0'
|
|
|
|
|
|
- # If the fw with same parameters exist already,
|
|
|
- # then, it is returned instead of creating a new one.
|
|
|
- rule = self.get_rule(ip_protocol, from_port, to_port, cidr)
|
|
|
-
|
|
|
- if rule:
|
|
|
- return rule
|
|
|
-
|
|
|
- count = len(self.rules) + 1
|
|
|
+ count = len(self.firewall._vm_firewall.security_rules) + 1
|
|
|
rule_name = "Rule - " + str(count)
|
|
|
priority = count * 100
|
|
|
destination_port_range = str(from_port) + "-" + str(to_port)
|
|
|
source_port_range = '*'
|
|
|
destination_address_prefix = "*"
|
|
|
access = "Allow"
|
|
|
- direction = "Inbound"
|
|
|
- parameters = {"protocol": ip_protocol,
|
|
|
+ direction = ("Inbound" if direction == TrafficDirection.INBOUND
|
|
|
+ else "Outbound")
|
|
|
+ parameters = {"protocol": protocol,
|
|
|
"source_port_range": source_port_range,
|
|
|
"destination_port_range": destination_port_range,
|
|
|
"priority": priority,
|
|
|
"source_address_prefix": cidr,
|
|
|
- "destination_address_prefix":
|
|
|
- destination_address_prefix,
|
|
|
- "access": access, "direction": direction}
|
|
|
+ "destination_address_prefix": destination_address_prefix,
|
|
|
+ "access": access,
|
|
|
+ "direction": direction}
|
|
|
result = self._provider.azure_client. \
|
|
|
- create_vm_firewall_rule(self.id,
|
|
|
+ create_vm_firewall_rule(self.firewall.id,
|
|
|
rule_name, parameters)
|
|
|
- self._vm_firewall.security_rules.append(result)
|
|
|
- return AzureVMFirewallRule(self._provider, result, self)
|
|
|
-
|
|
|
- def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
|
|
|
- cidr=None, src_dest_fw=None):
|
|
|
- for rule in self.rules:
|
|
|
- if (rule.ip_protocol == ip_protocol and
|
|
|
- rule.from_port == str(from_port) and
|
|
|
- rule.to_port == str(to_port) and
|
|
|
- rule.cidr == cidr):
|
|
|
- return rule
|
|
|
- return None
|
|
|
+ # pylint:disable=protected-access
|
|
|
+ self.firewall._vm_firewall.security_rules.append(result)
|
|
|
+ return AzureVMFirewallRule(self.firewall, result)
|
|
|
|
|
|
- def to_json(self):
|
|
|
- attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
|
|
|
- js = {k: v for (k, v) in attr if not k.startswith('_')}
|
|
|
- json_rules = [r.to_json() for r in self.rules]
|
|
|
- js['rules'] = [json.loads(r) for r in json_rules]
|
|
|
- if js.get('network_id'):
|
|
|
- js.pop('network_id') # Omit for consistency across cloud providers
|
|
|
- return js
|
|
|
+
|
|
|
+# Tuple for port range
|
|
|
+PortRange = collections.namedtuple('PortRange', ['from_port', 'to_port'])
|
|
|
|
|
|
|
|
|
class AzureVMFirewallRule(BaseVMFirewallRule):
|
|
|
- def __init__(self, provider, rule, parent):
|
|
|
- super(AzureVMFirewallRule, self).__init__(provider, rule, parent)
|
|
|
+ def __init__(self, parent_fw, rule):
|
|
|
+ super(AzureVMFirewallRule, self).__init__(parent_fw, rule)
|
|
|
|
|
|
@property
|
|
|
def id(self):
|
|
|
return self._rule.name
|
|
|
|
|
|
+ @property
|
|
|
+ def direction(self):
|
|
|
+ return (TrafficDirection.INBOUND if self._rule.direction == "Inbound"
|
|
|
+ else TrafficDirection.OUTBOUND)
|
|
|
+
|
|
|
@property
|
|
|
def name(self):
|
|
|
return self._rule.name
|
|
|
|
|
|
@property
|
|
|
- def ip_protocol(self):
|
|
|
+ def protocol(self):
|
|
|
return self._rule.protocol
|
|
|
|
|
|
@property
|
|
|
def from_port(self):
|
|
|
- if self._rule.destination_port_range == '*':
|
|
|
- return self._rule.destination_port_range
|
|
|
- destination_port_range = self._rule.destination_port_range
|
|
|
- port_range_split = destination_port_range.split('-', 1)
|
|
|
- return port_range_split[0]
|
|
|
+ return self._port_range_tuple().from_port
|
|
|
|
|
|
@property
|
|
|
def to_port(self):
|
|
|
+ return self._port_range_tuple().to_port
|
|
|
+
|
|
|
+ def _port_range_tuple(self):
|
|
|
if self._rule.destination_port_range == '*':
|
|
|
- return self._rule.destination_port_range
|
|
|
+ return PortRange(1, 65535)
|
|
|
destination_port_range = self._rule.destination_port_range
|
|
|
port_range_split = destination_port_range.split('-', 1)
|
|
|
- return port_range_split[1]
|
|
|
+ return PortRange(int(port_range_split[0]), int(port_range_split[1]))
|
|
|
|
|
|
@property
|
|
|
def cidr(self):
|
|
|
return self._rule.source_address_prefix
|
|
|
|
|
|
@property
|
|
|
- def group(self):
|
|
|
- return self.parent
|
|
|
+ def src_dest_fw_id(self):
|
|
|
+ return self.firewall.id
|
|
|
|
|
|
- def to_json(self):
|
|
|
- attr = inspect.getmembers(self, lambda a: not (inspect.isroutine(a)))
|
|
|
- js = {k: v for (k, v) in attr if not k.startswith('_')}
|
|
|
- js['group'] = self.group.id if self.group else ''
|
|
|
- js['parent'] = self.parent.id if self.parent else ''
|
|
|
- return json.dumps(js, sort_keys=True)
|
|
|
+ @property
|
|
|
+ def src_dest_fw(self):
|
|
|
+ return self.firewall
|
|
|
|
|
|
def delete(self):
|
|
|
- vm_firewall = self.parent.name
|
|
|
+ vm_firewall = self.firewall.name
|
|
|
self._provider.azure_client. \
|
|
|
delete_vm_firewall_rule(self.id, vm_firewall)
|
|
|
- for i, o in enumerate(self.parent._vm_firewall.security_rules):
|
|
|
+ for i, o in enumerate(self.firewall._vm_firewall.security_rules):
|
|
|
if o.name == self.name:
|
|
|
- del self.parent._vm_firewall.security_rules[i]
|
|
|
+ del self.firewall._vm_firewall.security_rules[i]
|
|
|
break
|
|
|
|
|
|
|