Sfoglia il codice sorgente

Resolve security group issues after the merge

Ehsan Chiniforooshan 9 anni fa
parent
commit
09a9da6b47

+ 2 - 0
cloudbridge/cloud/base/services.py

@@ -146,6 +146,8 @@ class BaseNetworkService(
         super(BaseNetworkService, self).__init__(provider)
 
     def delete(self, network_id):
+        if network_id is None:
+            return True
         network = self.get(network_id)
         if network:
             network.delete()

+ 56 - 51
cloudbridge/cloud/providers/gce/resources.py

@@ -21,6 +21,7 @@ import hashlib
 import inspect
 import json
 import re
+import uuid
 
 class GCEKeyPair(BaseKeyPair):
 
@@ -188,16 +189,16 @@ class GCEFirewallsDelegate(object):
         self._list_response = None
 
     @staticmethod
-    def tag_network_id(tag, network):
+    def tag_network_id(tag, network_name):
         """
-        Generate an ID for a (tag, network) pair.
+        Generate an ID for a (tag, network name) pair.
         """
         md5 = hashlib.md5()
-        md5.update("{0}-{1}".format(tag, network).encode('ascii'))
+        md5.update("{0}-{1}".format(tag, network_name).encode('ascii'))
         return md5.hexdigest()
 
     @staticmethod
-    def network(firewall):
+    def network_name(firewall):
         """
         Extract the network name of a firewall.
         """
@@ -217,60 +218,51 @@ class GCEFirewallsDelegate(object):
     @property
     def tag_networks(self):
         """
-        List all (tag, network) pairs that are used in at least one firewall.
+        List all (tag, network name) pairs that are in at least one firewall.
         """
         out = set()
         for firewall in self.iter_firewalls():
-            network = GCEFirewallsDelegate.network(firewall)
-            if network is not None:
-                out.add((firewall['targetTags'][0], network))
+            network_name = GCEFirewallsDelegate.network_name(firewall)
+            if network_name is not None:
+                out.add((firewall['targetTags'][0], network_name))
         return out
             
     def get_tag_network_from_id(self, tag_network_id):
         """
-        Map an ID back to the (tag, network) pair.
+        Map an ID back to the (tag, network name) pair.
         """
-        for tag, network in self.tag_networks:
-            current_id = GCEFirewallsDelegate.tag_network_id(tag, network)
+        for tag, network_name in self.tag_networks:
+            current_id = GCEFirewallsDelegate.tag_network_id(tag, network_name)
             if current_id == tag_network_id:
-                return (tag, network)
+                return (tag, network_name)
         return (None, None)
 
     def delete_tag_network_with_id(self, tag_network_id):
         """
         Delete all firewalls in a given network with a specific target tag.
         """
-        tag, network = self.get_tag_network_from_id(tag_network_id)
+        tag, network_name = self.get_tag_network_from_id(tag_network_id)
         if tag is None:
             return
-        for firewall in self.iter_firewalls(tag, network):
+        for firewall in self.iter_firewalls(tag, network_name):
             self._delete_firewall(firewall)
         self._update_list_response()
 
     def add_firewall(self, tag, ip_protocol, port, source_range, source_tag,
-                     description, network):
+                     description, network_name):
         """
         Create a new firewall.
         """
         if self.find_firewall(tag, ip_protocol, port, source_range,
-                              source_tag, network) is not None:
+                              source_tag, network_name) is not None:
             return True
         # Do not let the user accidentally open traffic from the world by not
         # explicitly specifying the source.
         if source_tag is None and source_range is None:
             return False
-        firewall_number = 1
-        suffixes = []
-        for firewall in self.iter_firewalls(tag, network):
-            suffix = firewall['name'].split('-')[-1]
-            if suffix.isdigit():
-                suffixes.append(int(suffix))
-        for suffix in sorted(suffixes):
-            if firewall_number == suffix:
-                firewall_number += 1
         firewall = {
-            'name': '%s-%s-rule-%d' % (network, tag, firewall_number),
-            'network': GCEFirewallsDelegate._NETWORK_URL_PREFIX + network,
+            'name': 'firewall-{0}'.format(uuid.uuid4()),
+            'network': GCEFirewallsDelegate._NETWORK_URL_PREFIX + network_name,
             'allowed': [{'IPProtocol': str(ip_protocol)}],
             'targetTags': [tag]}
         if description is not None:
@@ -297,13 +289,13 @@ class GCEFirewallsDelegate(object):
             self._update_list_response()
 
     def find_firewall(self, tag, ip_protocol, port, source_range, source_tag,
-                      network):
+                      network_name):
         """
         Find a firewall with give parameters.
         """
         if source_range is None and source_tag is None:
             source_range = '0.0.0.0/0'
-        for firewall in self.iter_firewalls(tag, network):
+        for firewall in self.iter_firewalls(tag, network_name):
             if firewall['allowed'][0]['IPProtocol'] != ip_protocol:
                 continue
             if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
@@ -337,7 +329,7 @@ class GCEFirewallsDelegate(object):
             if ('ports' in firewall['allowed'][0] and
                 len(firewall['allowed'][0]['ports']) == 1):
                 info['port'] = firewall['allowed'][0]['ports'][0]
-            info['network'] = GCEFirewallsDelegate.network(firewall)
+            info['network_name'] = GCEFirewallsDelegate.network_name(firewall)
             return info
         return info
 
@@ -350,7 +342,7 @@ class GCEFirewallsDelegate(object):
                 self._delete_firewall(firewall)
         self._update_list_response()
 
-    def iter_firewalls(self, tag=None, network=None):
+    def iter_firewalls(self, tag=None, network_name=None):
         """
         Iterate through all firewalls. Can optionally iterate through firewalls
         with a given tag and/or in a network.
@@ -366,11 +358,11 @@ class GCEFirewallsDelegate(object):
                 continue
             if tag is not None and firewall['targetTags'][0] != tag:
                 continue
-            if network is None:
+            if network_name is None:
                 yield firewall
                 continue
-            firewall_network = GCEFirewallsDelegate.network(firewall)
-            if firewall_network == network:
+            firewall_network_name = GCEFirewallsDelegate.network_name(firewall)
+            if firewall_network_name == network_name:
                 yield firewall
 
     def _delete_firewall(self, firewall):
@@ -415,15 +407,15 @@ class GCEFirewallsDelegate(object):
 
 class GCESecurityGroup(BaseSecurityGroup):
 
-    def __init__(self, delegate, tag,
-                 network=GCEFirewallsDelegate.DEFAULT_NETWORK,
-                 description=None):
+    def __init__(self, delegate, tag, network=None, description=None):
         super(GCESecurityGroup, self).__init__(delegate.provider, tag)
         self._description = description
         self._delegate = delegate
-        self._network = network
-        if self._network is None:
-            self._network = GCEFirewallsDelegate.DEFAULT_NETWORK 
+        if network is None:
+            self._network = delegate.provider.network.get_by_name(
+                    GCEFirewallsDelegate.DEFAULT_NETWORK)
+        else:
+            self._network = network
 
     @property
     def id(self):
@@ -432,7 +424,7 @@ class GCESecurityGroup(BaseSecurityGroup):
         network and the target tag corresponding to this security group.
         """
         return GCEFirewallsDelegate.tag_network_id(self._security_group,
-                                                   self._network)
+                                                   self._network.name)
 
     @property
     def name(self):
@@ -454,16 +446,20 @@ class GCESecurityGroup(BaseSecurityGroup):
         if self._description is not None:
             return self._description
         for firewall in self._delegate.iter_firewalls(self._security_group,
-                                                      self._network):
+                                                      self._network.name):
             if 'description' in firewall:
                 return firewall['description']
         return None
 
+    @property
+    def network_id(self):
+        return self._network.id
+
     @property
     def rules(self):
         out = []
         for firewall in self._delegate.iter_firewalls(self._security_group,
-                                                      self._network):
+                                                      self._network.name):
             out.append(GCESecurityGroupRule(self._delegate, firewall['id']))
         return out
 
@@ -482,7 +478,7 @@ class GCESecurityGroup(BaseSecurityGroup):
         src_tag = src_group.name if src_group is not None else None
         self._delegate.add_firewall(self._security_group, ip_protocol, port,
                                     cidr_ip, src_tag, self.description,
-                                    self._network)
+                                    self._network.name)
         return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
                              src_group)
 
@@ -492,7 +488,7 @@ class GCESecurityGroup(BaseSecurityGroup):
         src_tag = src_group.name if src_group is not None else None
         firewall_id = self._delegate.find_firewall(
                 self._security_group, ip_protocol, port, cidr_ip, src_tag,
-                self._network)
+                self._network.name)
         if firewall_id is None:
             return None
         return GCESecurityGroupRule(self._delegate, firewall_id)
@@ -522,10 +518,14 @@ class GCESecurityGroupRule(BaseSecurityGroupRule):
         Return the security group to which this rule belongs.
         """
         info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'target_tag' not in info or info['network'] is None:
+        if info is None:
             return None
-        return GCESecurityGroup(self._delegate, info['target_tag'],
-                                info['network'])
+        if 'target_tag' not in info or info['network_name'] is None:
+            return None
+        network = delegate.network.get_by_name(info['network_name'])
+        if network is None:
+            return None
+        return GCESecurityGroup(self._delegate, info['target_tag'], network)
 
     @property
     def id(self):
@@ -584,10 +584,15 @@ class GCESecurityGroupRule(BaseSecurityGroupRule):
         Return the security group from which this rule allows traffic.
         """
         info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'source_tag' not in info or info['network'] is None:
+        if info is None:
+            return None
+        if 'source_tag' not in info or info['network_name'] is None:
+            return None
+        network = self._delegate.provider.network.get_by_name(
+                info['network_name'])
+        if network is None:
             return None
-        return GCESecurityGroup(self._delegate, info['source_tag'],
-                                info['network'])
+        return GCESecurityGroup(self._delegate, info['source_tag'], network)
 
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))

+ 32 - 7
cloudbridge/cloud/providers/gce/services.py

@@ -202,19 +202,24 @@ class GCESecurityGroupService(BaseSecurityGroupService):
         self._delegate = GCEFirewallsDelegate(provider)
 
     def get(self, group_id):
-        tag, network = self._delegate.get_tag_network_from_id(group_id)
+        tag, network_name = self._delegate.get_tag_network_from_id(group_id)
         if tag is None:
             return None
+        network = self.provider.network.get_by_name(network_name)
         return GCESecurityGroup(self._delegate, tag, network)
 
     def list(self, limit=None, marker=None):
-        security_groups = [GCESecurityGroup(self._delegate, x, y)
-                           for x, y in self._delegate.tag_networks]
+        security_groups = []
+        for tag, network_name in self._delegate.tag_networks:
+            network = self.provider.network.get_by_name(network_name)
+            security_group = GCESecurityGroup(self._delegate, tag, network)
+            security_groups.append(security_group)
         return ClientPagedResultList(self.provider, security_groups,
                                      limit=limit, marker=marker)
 
     def create(self, name, description, network_id=None):
-        return GCESecurityGroup(self._delegate, name, network_id, description)
+        network = self.provider.network.get(network_id)
+        return GCESecurityGroup(self._delegate, name, network, description)
 
     def find(self, name, limit=None, marker=None):
         """
@@ -223,8 +228,9 @@ class GCESecurityGroupService(BaseSecurityGroupService):
         is returned.
         """
         out = []
-        for tag, network in self._delegate.tag_networks:
+        for tag, network_name in self._delegate.tag_networks:
             if tag == name:
+                network = self.provider.network.get_by_name(network_name)
                 out.append(GCESecurityGroup(self._delegate, name, network))
         return out
 
@@ -427,7 +433,22 @@ class GCENetworkService(BaseNetworkService):
         super(GCENetworkService, self).__init__(provider)
 
     def get(self, network_id):
-        networks = self.list(filter='id eq %s' % network_id)
+        if network_id is None:
+            return None
+        # networks = self.list(filter='id eq %s' % network_id) would be better.
+        # But, there is a GCE API bug that causes an error if the network_id
+        # has more than 19 digits. So, we list all networks and filter
+        # ourselves.
+        networks = self.list()
+        for network in networks:
+            if network.id == network_id:
+                return network
+        return None
+
+    def get_by_name(self, network_name):
+        if network_name is None:
+            return None
+        networks = self.list(filter='name eq %s' % network_name)
         return None if len(networks) == 0 else networks[0]
 
     def list(self, limit=None, marker=None, filter=None):
@@ -447,6 +468,10 @@ class GCENetworkService(BaseNetworkService):
 
     def create(self, name):
         try:
+            networks = self.list(filter='name eq %s' % name)
+            if len(networks) > 0:
+                return networks[0]
+
             response = (self.provider.gce_compute
                                      .networks()
                                      .insert(project=self.provider.project_name,
@@ -457,7 +482,7 @@ class GCENetworkService(BaseNetworkService):
             self.provider.wait_for_global_operation(response)
             networks = self.list(filter='name eq %s' % name)
             return None if len(networks) == 0 else networks[0]
-        except Exception as e:
+        except:
             return None
 
     @property

+ 6 - 6
test/test_security_service.py

@@ -14,7 +14,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             methodName=methodName, provider=provider)
 
     def test_crud_key_pair_service(self):
-        name = 'cbtestkeypair-a'
+        name = 'cbtestkeypairA-{0}'.format(uuid.uuid4()).lower()
         kp = self.provider.security.key_pairs.create(name=name)
         with helpers.cleanup_action(
             lambda:
@@ -65,7 +65,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             "Found a key pair {0} that should not exist?".format(no_kp))
 
     def test_key_pair(self):
-        name = 'cbtestkeypair-b'
+        name = 'cbtestkeypairB-{0}'.format(uuid.uuid4()).lower()
         kp = self.provider.security.key_pairs.create(name=name)
         with helpers.cleanup_action(lambda: kp.delete()):
             kpl = self.provider.security.key_pairs.list()
@@ -101,7 +101,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             self.provider.security.security_groups.delete(group_id=sg.id)
 
     def test_crud_security_group_service(self):
-        name = 'cbtestsecuritygroup-a'
+        name = 'cbtestsecuritygroupA-{0}'.format(uuid.uuid4()).lower()
         net = self.provider.network.create(name=name)
         sg = self.provider.security.security_groups.create(
             name=name, description=name, network_id=net.id)
@@ -157,7 +157,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
     def test_security_group(self):
         """Test for proper creation of a security group."""
-        name = 'cbtestsecuritygroup-b'
+        name = 'cbtestsecuritygroupB-{0}'.format(uuid.uuid4()).lower()
         net = self.provider.network.create(name=name)
         sg = self.provider.security.security_groups.create(
             name=name, description=name, network_id=net.id)
@@ -207,7 +207,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
     def test_security_group_rule_add_twice(self):
         """Test whether adding the same rule twice succeeds."""
-        name = 'cbtestsecuritygroupB-{0}'.format(uuid.uuid4())
+        name = 'cbtestsecuritygroupB-{0}'.format(uuid.uuid4()).lower()
         net = self.provider.network.create(name=name)
         sg = self.provider.security.security_groups.create(
             name=name, description=name, network_id=net.id)
@@ -224,7 +224,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
     def test_security_group_group_rule(self):
         """Test for proper creation of a security group rule."""
-        name = 'cbtestsecuritygroup-c'
+        name = 'cbtestsecuritygroupC-{0}'.format(uuid.uuid4()).lower()
         net = self.provider.network.create(name=name)
         sg = self.provider.security.security_groups.create(
             name=name, description=name, network_id=net.id)