Просмотр исходного кода

Implement firewall rule container

Also fix issues with sub-networks.
Ehsan Chiniforooshan 7 лет назад
Родитель
Сommit
343656f7e7

+ 5 - 7
cloudbridge/cloud/factory.py

@@ -83,13 +83,13 @@ class CloudProviderFactory(object):
         Imports and registers providers from the given module name.
         Imports and registers providers from the given module name.
         Raises an ImportError if the import does not succeed.
         Raises an ImportError if the import does not succeed.
         """
         """
-        log.info("Importing providers from %s", module_name)
+        log.debug("Importing providers from %s", module_name)
         module = importlib.import_module(
         module = importlib.import_module(
             "{0}.{1}".format(providers.__name__,
             "{0}.{1}".format(providers.__name__,
                              module_name))
                              module_name))
         classes = inspect.getmembers(module, inspect.isclass)
         classes = inspect.getmembers(module, inspect.isclass)
         for _, cls in classes:
         for _, cls in classes:
-            log.info("Registering the provider: %s", cls)
+            log.debug("Registering the provider: %s", cls)
             self.register_provider_class(cls)
             self.register_provider_class(cls)
 
 
     def list_providers(self):
     def list_providers(self):
@@ -110,7 +110,7 @@ class CloudProviderFactory(object):
         """
         """
         if not self.provider_list:
         if not self.provider_list:
             self.discover_providers()
             self.discover_providers()
-        log.info("List of available providers: %s", self.provider_list)
+        log.debug("List of available providers: %s", self.provider_list)
         return self.provider_list
         return self.provider_list
 
 
     def create_provider(self, name, config):
     def create_provider(self, name, config):
@@ -131,8 +131,7 @@ class CloudProviderFactory(object):
         :return:  a concrete provider instance
         :return:  a concrete provider instance
         :rtype: ``object`` of :class:`.CloudProvider`
         :rtype: ``object`` of :class:`.CloudProvider`
         """
         """
-        log.info("Searching provider with the name %s on %s",
-                 name, config)
+        log.info("Creating '%s' provider", name)
         provider_class = self.get_provider_class(name)
         provider_class = self.get_provider_class(name)
         if provider_class is None:
         if provider_class is None:
             log.exception("A provider with the name %s could not "
             log.exception("A provider with the name %s could not "
@@ -140,8 +139,7 @@ class CloudProviderFactory(object):
             raise NotImplementedError(
             raise NotImplementedError(
                 'A provider with name {0} could not be'
                 'A provider with name {0} could not be'
                 ' found'.format(name))
                 ' found'.format(name))
-        log.debug("Found provider name: %s with these config "
-                  " details: %s", name, config)
+        log.debug("Created '%s' provider", name)
         return provider_class(config)
         return provider_class(config)
 
 
     def get_provider_class(self, name, get_mock=False):
     def get_provider_class(self, name, get_mock=False):

+ 2 - 0
cloudbridge/cloud/providers/gce/provider.py

@@ -311,6 +311,8 @@ class GCECloudProvider(BaseCloudProvider):
         return out if out else self._storage_resources.parse_url(url)
         return out if out else self._storage_resources.parse_url(url)
 
 
     def get_resource(self, resource, url_or_name, **kwargs):
     def get_resource(self, resource, url_or_name, **kwargs):
+        if not url_or_name:
+            return None
         resource_url = (
         resource_url = (
             self._compute_resources.get_resource_url_with_default(
             self._compute_resources.get_resource_url_with_default(
                 resource, url_or_name, **kwargs) or
                 resource, url_or_name, **kwargs) or

+ 150 - 118
cloudbridge/cloud/providers/gce/resources.py

@@ -4,7 +4,6 @@ DataTypes used by this provider
 import hashlib
 import hashlib
 import inspect
 import inspect
 import io
 import io
-import json
 import math
 import math
 import uuid
 import uuid
 
 
@@ -30,6 +29,7 @@ from cloudbridge.cloud.base.resources import BaseSnapshot
 from cloudbridge.cloud.base.resources import BaseSubnet
 from cloudbridge.cloud.base.resources import BaseSubnet
 from cloudbridge.cloud.base.resources import BaseVMFirewall
 from cloudbridge.cloud.base.resources import BaseVMFirewall
 from cloudbridge.cloud.base.resources import BaseVMFirewallRule
 from cloudbridge.cloud.base.resources import BaseVMFirewallRule
+from cloudbridge.cloud.base.resources import BaseVMFirewallRuleContainer
 from cloudbridge.cloud.base.resources import BaseVMType
 from cloudbridge.cloud.base.resources import BaseVMType
 from cloudbridge.cloud.base.resources import BaseVolume
 from cloudbridge.cloud.base.resources import BaseVolume
 from cloudbridge.cloud.base.resources import ClientPagedResultList
 from cloudbridge.cloud.base.resources import ClientPagedResultList
@@ -41,6 +41,7 @@ from cloudbridge.cloud.interfaces.resources import NetworkState
 from cloudbridge.cloud.interfaces.resources import RouterState
 from cloudbridge.cloud.interfaces.resources import RouterState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import SubnetState
 from cloudbridge.cloud.interfaces.resources import SubnetState
+from cloudbridge.cloud.interfaces.resources import TrafficDirection
 from cloudbridge.cloud.interfaces.resources import VolumeState
 from cloudbridge.cloud.interfaces.resources import VolumeState
 from cloudbridge.cloud.providers.gce import helpers
 from cloudbridge.cloud.providers.gce import helpers
 
 
@@ -269,31 +270,42 @@ class GCEFirewallsDelegate(object):
             self._delete_firewall(firewall)
             self._delete_firewall(firewall)
         self._update_list_response()
         self._update_list_response()
 
 
-    def add_firewall(self, tag, ip_protocol, port, source_range, source_tag,
-                     description, network_name):
+    def add_firewall(self, tag, direction, protocol, port, src_dest_range,
+                     src_dest_tag, description, network_name):
         """
         """
         Create a new firewall.
         Create a new firewall.
         """
         """
-        if self.find_firewall(tag, ip_protocol, port, source_range,
-                              source_tag, network_name) is not None:
+        if self.find_firewall(
+                tag, direction, protocol, port, src_dest_range, src_dest_tag,
+                network_name) is not None:
             return True
             return True
         # Do not let the user accidentally open traffic from the world by not
         # Do not let the user accidentally open traffic from the world by not
         # explicitly specifying the source.
         # explicitly specifying the source.
-        if source_tag is None and source_range is None:
+        if src_dest_tag is None and src_dest_range is None:
             return False
             return False
         firewall = {
         firewall = {
             'name': 'firewall-{0}'.format(uuid.uuid4()),
             'name': 'firewall-{0}'.format(uuid.uuid4()),
             'network': GCEFirewallsDelegate._NETWORK_URL_PREFIX + network_name,
             'network': GCEFirewallsDelegate._NETWORK_URL_PREFIX + network_name,
-            'allowed': [{'IPProtocol': str(ip_protocol)}],
+            'allowed': [{'IPProtocol': str(protocol)}],
             'targetTags': [tag]}
             'targetTags': [tag]}
         if description is not None:
         if description is not None:
             firewall['description'] = description
             firewall['description'] = description
         if port is not None:
         if port is not None:
             firewall['allowed'][0]['ports'] = [port]
             firewall['allowed'][0]['ports'] = [port]
-        if source_range is not None:
-            firewall['sourceRanges'] = [source_range]
-        if source_tag is not None:
-            firewall['sourceTags'] = [source_tag]
+        if direction == TrafficDirection.INBOUND:
+            firewall['direction'] = 'INGRESS'
+            src_dest_str = 'source'
+        else:
+            firewall['direction'] = 'EGRESS'
+            src_dest_str = 'destination'
+        if src_dest_range is not None:
+            firewall[src_dest_str + 'Ranges'] = [src_dest_range]
+        if src_dest_tag is not None:
+            if direction == TrafficDirection.OUTBOUND:
+                cb.log.warning('GCP does not support egress rules to network '
+                               'tags. Only IP ranges are acceptable.')
+            else:
+                firewall['sourceTags'] = [src_dest_tag]
         project_name = self._provider.project_name
         project_name = self._provider.project_name
         try:
         try:
             response = (self._provider
             response = (self._provider
@@ -311,24 +323,28 @@ class GCEFirewallsDelegate(object):
             self._update_list_response()
             self._update_list_response()
         return True
         return True
 
 
-    def find_firewall(self, tag, ip_protocol, port, source_range, source_tag,
-                      network_name):
+    def find_firewall(self, tag, direction, protocol, port, src_dest_range,
+                      src_dest_tag, network_name):
         """
         """
         Find a firewall with give parameters.
         Find a firewall with give parameters.
         """
         """
-        if source_range is None and source_tag is None:
-            source_range = '0.0.0.0/0'
+        if src_dest_range is None and src_dest_tag is None:
+            src_dest_range = '0.0.0.0/0'
+        if direction == TrafficDirection.INBOUND:
+            src_dest_str = 'source'
+        else:
+            src_dest_str = 'destination'
         for firewall in self.iter_firewalls(tag, network_name):
         for firewall in self.iter_firewalls(tag, network_name):
-            if firewall['allowed'][0]['IPProtocol'] != ip_protocol:
+            if firewall['allowed'][0]['IPProtocol'] != protocol:
                 continue
                 continue
             if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
             if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
                                             port):
                                             port):
                 continue
                 continue
-            if not self._check_list_in_dict(firewall, 'sourceRanges',
-                                            source_range):
+            if not self._check_list_in_dict(firewall, src_dest_str + 'Ranges',
+                                            src_dest_range):
                 continue
                 continue
-            if not self._check_list_in_dict(firewall, 'sourceTags',
-                                            source_tag):
+            if not self._check_list_in_dict(firewall, src_dest_str + 'Tags',
+                                            src_dest_tag):
                 continue
                 continue
             return firewall['id']
             return firewall['id']
         return None
         return None
@@ -343,17 +359,22 @@ class GCEFirewallsDelegate(object):
                 continue
                 continue
             if ('sourceRanges' in firewall and
             if ('sourceRanges' in firewall and
                     len(firewall['sourceRanges']) == 1):
                     len(firewall['sourceRanges']) == 1):
-                info['source_range'] = firewall['sourceRanges'][0]
+                info['src_dest_range'] = firewall['sourceRanges'][0]
+            elif ('destinationRanges' in firewall and
+                    len(firewall['destinationRanges']) == 1):
+                info['src_dest_range'] = firewall['destinationRanges'][0]
             if 'sourceTags' in firewall and len(firewall['sourceTags']) == 1:
             if 'sourceTags' in firewall and len(firewall['sourceTags']) == 1:
-                info['source_tag'] = firewall['sourceTags'][0]
+                info['src_dest_tag'] = firewall['sourceTags'][0]
             if 'targetTags' in firewall and len(firewall['targetTags']) == 1:
             if 'targetTags' in firewall and len(firewall['targetTags']) == 1:
                 info['target_tag'] = firewall['targetTags'][0]
                 info['target_tag'] = firewall['targetTags'][0]
             if 'IPProtocol' in firewall['allowed'][0]:
             if 'IPProtocol' in firewall['allowed'][0]:
-                info['ip_protocol'] = firewall['allowed'][0]['IPProtocol']
+                info['protocol'] = firewall['allowed'][0]['IPProtocol']
             if ('ports' in firewall['allowed'][0] and
             if ('ports' in firewall['allowed'][0] and
                     len(firewall['allowed'][0]['ports']) == 1):
                     len(firewall['allowed'][0]['ports']) == 1):
                 info['port'] = firewall['allowed'][0]['ports'][0]
                 info['port'] = firewall['allowed'][0]['ports'][0]
             info['network_name'] = self.network_name(firewall)
             info['network_name'] = self.network_name(firewall)
+            if 'direction' in firewall:
+                info['direction'] = firewall['direction']
             return info
             return info
         return info
         return info
 
 
@@ -431,13 +452,14 @@ class GCEVMFirewall(BaseVMFirewall):
 
 
     def __init__(self, delegate, tag, network=None, description=None):
     def __init__(self, delegate, tag, network=None, description=None):
         super(GCEVMFirewall, self).__init__(delegate.provider, tag)
         super(GCEVMFirewall, self).__init__(delegate.provider, tag)
-        self._description = description
         self._delegate = delegate
         self._delegate = delegate
+        self._description = description
         if network is None:
         if network is None:
             self._network = delegate.provider.networking.networks.get_by_name(
             self._network = delegate.provider.networking.networks.get_by_name(
                     GCEFirewallsDelegate.DEFAULT_NETWORK)
                     GCEFirewallsDelegate.DEFAULT_NETWORK)
         else:
         else:
             self._network = network
             self._network = network
+        self._rule_container = GCEVMFirewallRuleContainer(self)
 
 
     @property
     @property
     def id(self):
     def id(self):
@@ -466,13 +488,14 @@ class GCEVMFirewall(BaseVMFirewall):
         If the GCE firewalls are created using this API, they all have the same
         If the GCE firewalls are created using this API, they all have the same
         description.
         description.
         """
         """
-        if self._description is not None:
-            return self._description
-        for firewall in self._delegate.iter_firewalls(self._vm_firewall,
-                                                      self._network.name):
-            if 'description' in firewall:
-                return firewall['description']
-        return None
+        if self._description is None:
+            for firewall in self._delegate.iter_firewalls(self._vm_firewall,
+                                                          self._network.name):
+                if 'description' in firewall:
+                    self._description = firewall['description']
+        if self._description is None:
+            self._description = ''
+        return self._description
 
 
     @property
     @property
     def network_id(self):
     def network_id(self):
@@ -480,11 +503,41 @@ class GCEVMFirewall(BaseVMFirewall):
 
 
     @property
     @property
     def rules(self):
     def rules(self):
-        out = []
-        for firewall in self._delegate.iter_firewalls(self._vm_firewall,
-                                                      self._network.name):
-            out.append(GCEVMFirewallRule(self._delegate, firewall['id']))
-        return out
+        return self._rule_container
+
+    def delete(self):
+        for rule in self._rule_container:
+            rule.delete()
+
+    def to_json(self):
+        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
+        js = {k: v for(k, v) in attr if not k.startswith('_')}
+        json_rules = [r.to_json() for r in self.rules]
+        js['rules'] = json_rules
+        return js
+
+    @property
+    def network(self):
+        return self._network
+
+    @property
+    def delegate(self):
+        return self._delegate
+
+
+class GCEVMFirewallRuleContainer(BaseVMFirewallRuleContainer):
+
+    def __init__(self, firewall):
+        super(GCEVMFirewallRuleContainer, self).__init__(
+                firewall.delegate.provider, firewall)
+
+    def list(self, limit=None, marker=None):
+        rules = []
+        for firewall in self.firewall.delegate.iter_firewalls(
+                self.firewall.name, self.firewall.network.name):
+            rules.append(GCEVMFirewallRule(self.firewall, firewall['id']))
+        return ClientPagedResultList(self._provider, rules,
+                                     limit=limit, marker=marker)
 
 
     @staticmethod
     @staticmethod
     def to_port_range(from_port, to_port):
     def to_port_range(from_port, to_port):
@@ -495,75 +548,55 @@ class GCEVMFirewall(BaseVMFirewall):
         else:
         else:
             return to_port
             return to_port
 
 
-    def add_rule(self, ip_protocol, from_port=None, to_port=None,
-                 cidr_ip=None, src_group=None):
-        port = GCEVMFirewall.to_port_range(from_port, to_port)
-        src_tag = src_group.name if src_group is not None else None
-        self._delegate.add_firewall(self._vm_firewall, ip_protocol, port,
-                                    cidr_ip, src_tag, self.description,
-                                    self._network.name)
-        return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
-                             src_group)
-
-    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
-                 cidr_ip=None, src_group=None):
-        port = GCEVMFirewall.to_port_range(from_port, to_port)
-        src_tag = src_group.name if src_group is not None else None
-        firewall_id = self._delegate.find_firewall(
-                self._vm_firewall, ip_protocol, port, cidr_ip, src_tag,
-                self._network.name)
-        if firewall_id is None:
+    def create(self, direction, protocol, from_port=None, to_port=None,
+               cidr=None, src_dest_fw=None):
+        port = GCEVMFirewallRuleContainer.to_port_range(from_port, to_port)
+        src_dest_tag = None
+        src_dest_fw_id = None
+        if src_dest_fw:
+            src_dest_tag = src_dest_fw.name
+            src_dest_fw_id = src_dest_fw.id
+        if not self.firewall.delegate.add_firewall(
+                self.firewall.name, direction, protocol, port, cidr,
+                src_dest_tag, self.firewall.description,
+                self.firewall.network.name):
             return None
             return None
-        return GCEVMFirewallRule(self._delegate, firewall_id)
-
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
-        js = {k: v for(k, v) in attr if not k.startswith('_')}
-        json_rules = [r.to_json() for r in self.rules]
-        js['rules'] = [json.loads(r) for r in json_rules]
-        return json.dumps(js, sort_keys=True)
-
-    def delete(self):
-        for rule in self.rules:
-            rule.delete()
+        rules = self.find(direction=direction, protocol=protocol,
+                          from_port=from_port, to_port=to_port, cidr=cidr,
+                          src_dest_fw_id=src_dest_fw_id)
+        if len(rules) < 1:
+            return None
+        return rules[0]
 
 
 
 
 class GCEVMFirewallRule(BaseVMFirewallRule):
 class GCEVMFirewallRule(BaseVMFirewallRule):
 
 
-    def __init__(self, delegate, firewall_id):
-        super(GCEVMFirewallRule, self).__init__(
-                delegate.provider, firewall_id, None)
-        self._delegate = delegate
-
-    @property
-    def parent(self):
-        """
-        Return the VM firewall to which this rule belongs.
-        """
-        info = self._delegate.get_firewall_info(self._rule)
-        if info is None:
-            return None
-        if 'target_tag' not in info or info['network_name'] is None:
-            return None
-        network = self._delegate.network.get_by_name(info['network_name'])
-        if network is None:
-            return None
-        return GCEVMFirewall(self._delegate, info['target_tag'], network)
+    def __init__(self, parent_fw, rule):
+        super(GCEVMFirewallRule, self).__init__(parent_fw, rule)
 
 
     @property
     @property
     def id(self):
     def id(self):
         return self._rule
         return self._rule
 
 
     @property
     @property
-    def ip_protocol(self):
-        info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'ip_protocol' not in info:
+    def direction(self):
+        info = self.firewall.delegate.get_firewall_info(self._rule)
+        if info is None:
+            return None
+        if 'direction' in info and info['direction'] == 'EGRESS':
+            return TrafficDirection.OUTBOUND
+        return TrafficDirection.INBOUND
+
+    @property
+    def protocol(self):
+        info = self.firewall.delegate.get_firewall_info(self._rule)
+        if info is None or 'protocol' not in info:
             return None
             return None
-        return info['ip_protocol']
+        return info['protocol']
 
 
     @property
     @property
     def from_port(self):
     def from_port(self):
-        info = self._delegate.get_firewall_info(self._rule)
+        info = self.firewall.delegate.get_firewall_info(self._rule)
         if info is None or 'port' not in info:
         if info is None or 'port' not in info:
             return 0
             return 0
         port = info['port']
         port = info['port']
@@ -578,7 +611,7 @@ class GCEVMFirewallRule(BaseVMFirewallRule):
 
 
     @property
     @property
     def to_port(self):
     def to_port(self):
-        info = self._delegate.get_firewall_info(self._rule)
+        info = self.firewall.delegate.get_firewall_info(self._rule)
         if info is None or 'port' not in info:
         if info is None or 'port' not in info:
             return 0
             return 0
         port = info['port']
         port = info['port']
@@ -592,40 +625,37 @@ class GCEVMFirewallRule(BaseVMFirewallRule):
         return 0
         return 0
 
 
     @property
     @property
-    def cidr_ip(self):
+    def cidr(self):
+        info = self.firewall.delegate.get_firewall_info(self._rule)
+        if info is None or 'src_dest_range' not in info:
+            return None
+        return info['src_dest_range']
+
+    @property
+    def src_dest_fw_id(self):
         """
         """
-        Return the IP of machines from which this rule allows traffic.
+        Return the VM firewall given access by this rule.
         """
         """
-        info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'source_range' not in info:
+        info = self.firewall.delegate.get_firewall_info(self._rule)
+        if info is None or 'src_dest_tag' not in info:
             return None
             return None
-        return info['source_range']
+        return GCEFirewallsDelegate.tag_network_id(info['src_dest_tag'],
+                                                   self.firewall.network.name)
 
 
     @property
     @property
-    def group(self):
+    def src_dest_fw(self):
         """
         """
-        Return the VM firewall from which this rule allows traffic.
+        Return the VM firewall given access by this rule.
         """
         """
-        info = self._delegate.get_firewall_info(self._rule)
-        if info is None:
+        info = self.firewall.delegate.get_firewall_info(self._rule)
+        if info is None or 'src_dest_tag' not in info:
             return None
             return None
-        if 'source_tag' not in info or info['network_name'] is None:
-            return None
-        network = self._delegate.provider.networking.networks.get_by_name(
-                info['network_name'])
-        if network is None:
-            return None
-        return GCEVMFirewall(self._delegate, info['source_tag'], network)
-
-    def to_json(self):
-        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
-        js = {k: v for(k, v) in attr if not k.startswith('_')}
-        js['group'] = self.group.id if self.group else ''
-        js['parent'] = self.parent.id if self.parent else ''
-        return json.dumps(js, sort_keys=True)
+        return GCEVMFirewall(
+                self.firewall.delegate, info['src_dest_tag'],
+                self.firewall.network)
 
 
     def delete(self):
     def delete(self):
-        self._delegate.delete_firewall_id(self._rule)
+        self.firewall.delegate.delete_firewall_id(self._rule)
 
 
 
 
 class GCEMachineImage(BaseMachineImage):
 class GCEMachineImage(BaseMachineImage):
@@ -763,7 +793,7 @@ class GCEInstance(BaseInstance):
         # initially creating the resource. The name cannot be changed after
         # initially creating the resource. The name cannot be changed after
         # the instance is created.
         # the instance is created.
         cb.log.warning("Setting instance name after it is created is not "
         cb.log.warning("Setting instance name after it is created is not "
-                       "supported by this provider.")
+                       "supported by this provider: %s", value)
 
 
     @property
     @property
     def public_ips(self):
     def public_ips(self):
@@ -855,6 +885,7 @@ class GCEInstance(BaseInstance):
                                 instance=self.name)
                                 instance=self.name)
                         .execute())
                         .execute())
         self._provider.wait_for_operation(response, zone=self.zone_name)
         self._provider.wait_for_operation(response, zone=self.zone_name)
+        self._gce_instance = {'status': 'UNKNOWN'}
 
 
     def stop(self):
     def stop(self):
         """
         """
@@ -1176,9 +1207,10 @@ class GCEInstance(BaseInstance):
         Refreshes the state of this instance by re-querying the cloud provider
         Refreshes the state of this instance by re-querying the cloud provider
         for its latest state.
         for its latest state.
         """
         """
+        name = self.name
         self._gce_instance = self._provider.get_resource('instances', self.id)
         self._gce_instance = self._provider.get_resource('instances', self.id)
         if not self._gce_instance:
         if not self._gce_instance:
-            self._gce_instance = {'status': 'UNKNOWN'}
+            self._gce_instance = {'name': name, 'status': 'UNKNOWN'}
 
 
     def add_vm_firewall(self, sg):
     def add_vm_firewall(self, sg):
         tag = sg.name if isinstance(sg, GCEVMFirewall) else sg
         tag = sg.name if isinstance(sg, GCEVMFirewall) else sg
@@ -1594,7 +1626,7 @@ class GCESubnet(BaseSubnet):
 
 
     @property
     @property
     def region_name(self):
     def region_name(self):
-        parsed_url = self.provider.parse_url(self.id)
+        parsed_url = self._provider.parse_url(self.id)
         return parsed_url.parameters['region']
         return parsed_url.parameters['region']
 
 
     @property
     @property

+ 5 - 3
cloudbridge/cloud/providers/gce/services.py

@@ -874,15 +874,17 @@ class GCESubnetService(BaseSubnetService):
         instead of creating a new subnet. In this case, other parameters, i.e.
         instead of creating a new subnet. In this case, other parameters, i.e.
         the name and the zone, are ignored.
         the name and the zone, are ignored.
         """
         """
+        if not name:
+            name = 'subnet-{0}'.format(uuid.uuid4())
         GCESubnet.assert_valid_resource_name(name)
         GCESubnet.assert_valid_resource_name(name)
+        region_name = self._zone_to_region_name(zone)
         subnets = self.list(network)
         subnets = self.list(network)
         for subnet in subnets:
         for subnet in subnets:
             if BaseNetwork.cidr_blocks_overlap(subnet.cidr_block, cidr_block):
             if BaseNetwork.cidr_blocks_overlap(subnet.cidr_block, cidr_block):
                 return subnet
                 return subnet
+            if subnet.name == name and subnet.region_name == region_name:
+                return subnet
 
 
-        if not name:
-            name = 'subnet-{0}'.format(uuid.uuid4())
-        region_name = self._zone_to_region_name(zone)
         body = {'ipCidrRange': cidr_block,
         body = {'ipCidrRange': cidr_block,
                 'name': name,
                 'name': name,
                 'network': network.resource_url,
                 'network': network.resource_url,

+ 3 - 0
test/test_compute_service.py

@@ -9,6 +9,7 @@ from cloudbridge.cloud.interfaces import InvalidConfigurationException
 from cloudbridge.cloud.interfaces.exceptions import WaitStateException
 from cloudbridge.cloud.interfaces.exceptions import WaitStateException
 from cloudbridge.cloud.interfaces.resources import Instance
 from cloudbridge.cloud.interfaces.resources import Instance
 from cloudbridge.cloud.interfaces.resources import SnapshotState
 from cloudbridge.cloud.interfaces.resources import SnapshotState
+from cloudbridge.cloud.interfaces.resources import TrafficDirection
 from cloudbridge.cloud.interfaces.resources import VMType
 from cloudbridge.cloud.interfaces.resources import VMType
 
 
 import six
 import six
@@ -318,6 +319,8 @@ class CloudComputeServiceTestCase(ProviderTestBase):
                                                   subnet=subnet)
                                                   subnet=subnet)
             fw = self.provider.security.vm_firewalls.create(
             fw = self.provider.security.vm_firewalls.create(
                 name=name, description=name, network_id=net.id)
                 name=name, description=name, network_id=net.id)
+            fw.rules.create(direction=TrafficDirection.INBOUND, protocol='tcp',
+                            from_port=1111, to_port=1111, cidr='0.0.0.0/0')
 
 
             # Check adding a VM firewall to a running instance
             # Check adding a VM firewall to a running instance
             test_inst.add_vm_firewall(fw)
             test_inst.add_vm_firewall(fw)