Browse Source

Security groups in custom networks

Ehsan Chiniforooshan 9 years ago
parent
commit
c15095ac45

+ 75 - 36
cloudbridge/cloud/providers/gce/resources.py

@@ -17,6 +17,7 @@ except NameError:
 import hashlib
 import hashlib
 import inspect
 import inspect
 import json
 import json
+import re
 
 
 class GCEKeyPair(BaseKeyPair):
 class GCEKeyPair(BaseKeyPair):
 
 
@@ -176,49 +177,61 @@ class GCERegion(BaseRegion):
 
 
 
 
 class GCEFirewallsDelegate(object):
 class GCEFirewallsDelegate(object):
+    NETWORK_URL_PREFIX = 'global/networks/'
+    DEFAULT_NETWORK = 'default'
   
   
     def __init__(self, provider):
     def __init__(self, provider):
         self._provider = provider
         self._provider = provider
         self._list_response = None
         self._list_response = None
 
 
     @staticmethod
     @staticmethod
-    def tag_id(tag):
+    def tagnet_id(tag, network):
         md5 = hashlib.md5()
         md5 = hashlib.md5()
-        md5.update(tag.encode('ascii'))
+        md5.update("{0}-{1}".format(tag, network).encode('ascii'))
         return md5.hexdigest()
         return md5.hexdigest()
 
 
+    @staticmethod
+    def network(firewall):
+        if 'network' not in firewall:
+            return GCEFirewallsDelegate.DEFAULT_NETWORK
+        match = re.search(
+                GCEFirewallsDelegate.NETWORK_URL_PREFIX + '([^/]*)$',
+                firewall['network'])
+        if match and len(match.groups()) == 1:
+            return match.group(1)
+        return None
+
     @property
     @property
     def provider(self):
     def provider(self):
         return self._provider
         return self._provider
 
 
     @property
     @property
-    def tags(self):
+    def tagnets(self):
         out = set()
         out = set()
         for firewall in self.iter_firewalls():
         for firewall in self.iter_firewalls():
-            out.add(firewall['targetTags'][0])
+            network = GCEFirewallsDelegate.network(firewall)
+            if network is not None:
+                out.add((firewall['targetTags'][0], network))
         return out
         return out
             
             
-    def get_tag_from_id(self, tag_id):
-        for tag in self.tags:
-            if GCEFirewallsDelegate.tag_id(tag) == tag_id:
-                return tag
-        return None
-
-    def has_tag(self, tag):
-        return tag in self.tags
-
-    def delete_tag_with_id(self, tag_id):
-        tag = self.get_tag_from_id(tag_id)
+    def get_tagnet_from_id(self, tagnet_id):
+        for tag, network in self.tagnets:
+            if GCEFirewallsDelegate.tagnet_id(tag, network) == tagnet_id:
+                return (tag, network)
+        return (None, None)
+
+    def delete_tagnet_with_id(self, tagnet_id):
+        tag, network = self.get_tagnet_from_id(tagnet_id)
         if tag is None:
         if tag is None:
             return
             return
-        for firewall in self.iter_firewalls(tag):
+        for firewall in self.iter_firewalls(tag, network):
             self._delete_firewall(firewall)
             self._delete_firewall(firewall)
         self._update_list_response()
         self._update_list_response()
 
 
     def add_firewall(self, tag, ip_protocol, port, source_range, source_tag,
     def add_firewall(self, tag, ip_protocol, port, source_range, source_tag,
-                     description):
+                     description, network):
         if self.find_firewall(tag, ip_protocol, port, source_range,
         if self.find_firewall(tag, ip_protocol, port, source_range,
-                              source_tag) is not None:
+                              source_tag, network) is not None:
             return True
             return True
         # Do not let the user accidentally open traffic from the world by not
         # Do not let the user accidentally open traffic from the world by not
         # explicitly specifying the source.
         # explicitly specifying the source.
@@ -226,16 +239,18 @@ class GCEFirewallsDelegate(object):
             return False
             return False
         firewall_number = 1
         firewall_number = 1
         suffixes = []
         suffixes = []
-        for firewall in self.iter_firewalls(tag):
+        for firewall in self.iter_firewalls(tag, network):
             suffix = firewall['name'].split('-')[-1]
             suffix = firewall['name'].split('-')[-1]
             if suffix.isdigit():
             if suffix.isdigit():
                 suffixes.append(int(suffix))
                 suffixes.append(int(suffix))
         for suffix in sorted(suffixes):
         for suffix in sorted(suffixes):
             if firewall_number == suffix:
             if firewall_number == suffix:
                 firewall_number += 1
                 firewall_number += 1
-        firewall = {'name': '%s-rule-%d' % (tag, firewall_number),
-                    'allowed': [{'IPProtocol': str(ip_protocol)}],
-                    'targetTags': [tag]}
+        firewall = {
+            'name': '%s-%s-rule-%d' % (network, tag, firewall_number),
+            'network': GCEFirewallsDelegate.NETWORK_URL_PREFIX + network,
+            'allowed': [{'IPProtocol': str(ip_protocol)}],
+            'targetTags': [tag]}
         if description is not None:
         if description is not None:
             firewall['description'] = description
             firewall['description'] = description
         if port is not None:
         if port is not None:
@@ -259,10 +274,11 @@ class GCEFirewallsDelegate(object):
         finally:
         finally:
             self._update_list_response()
             self._update_list_response()
 
 
-    def find_firewall(self, tag, ip_protocol, port, source_range, source_tag):
+    def find_firewall(self, tag, ip_protocol, port, source_range, source_tag,
+                      network):
         if source_range is None and source_tag is None:
         if source_range is None and source_tag is None:
             source_range = '0.0.0.0/0'
             source_range = '0.0.0.0/0'
-        for firewall in self.iter_firewalls(tag):
+        for firewall in self.iter_firewalls(tag, network):
             if firewall['allowed'][0]['IPProtocol'] != ip_protocol:
             if firewall['allowed'][0]['IPProtocol'] != ip_protocol:
                 continue
                 continue
             if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
             if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
@@ -293,6 +309,7 @@ class GCEFirewallsDelegate(object):
             if ('ports' in firewall['allowed'][0] and
             if ('ports' in firewall['allowed'][0] and
                 len(firewall['allowed'][0]['ports']) == 1):
                 len(firewall['allowed'][0]['ports']) == 1):
                 info['port'] = firewall['allowed'][0]['ports'][0]
                 info['port'] = firewall['allowed'][0]['ports'][0]
+            info['network'] = GCEFirewallsDelegate.network(firewall)
             return info
             return info
         return info
         return info
 
 
@@ -302,7 +319,7 @@ class GCEFirewallsDelegate(object):
                 self._delete_firewall(firewall)
                 self._delete_firewall(firewall)
         self._update_list_response()
         self._update_list_response()
 
 
-    def iter_firewalls(self, tag=None):
+    def iter_firewalls(self, tag=None, network=None):
         if self._list_response is None:
         if self._list_response is None:
             self._update_list_response()
             self._update_list_response()
         if 'items' not in self._list_response:
         if 'items' not in self._list_response:
@@ -312,7 +329,13 @@ class GCEFirewallsDelegate(object):
                 continue
                 continue
             if 'allowed' not in firewall or len(firewall['allowed']) != 1:
             if 'allowed' not in firewall or len(firewall['allowed']) != 1:
                 continue
                 continue
-            if tag is None or firewall['targetTags'][0] == tag:
+            if tag is not None and firewall['targetTags'][0] != tag:
+                continue
+            if network is None:
+                yield firewall
+                continue
+            firewall_network = GCEFirewallsDelegate.network(firewall)
+            if firewall_network == network:
                 yield firewall
                 yield firewall
 
 
     def _delete_firewall(self, firewall):
     def _delete_firewall(self, firewall):
@@ -348,14 +371,20 @@ class GCEFirewallsDelegate(object):
 
 
 class GCESecurityGroup(BaseSecurityGroup):
 class GCESecurityGroup(BaseSecurityGroup):
 
 
-    def __init__(self, delegate, tag, description=None):
+    def __init__(self, delegate, tag,
+                 network=GCEFirewallsDelegate.DEFAULT_NETWORK,
+                 description=None):
         super(GCESecurityGroup, self).__init__(delegate.provider, tag)
         super(GCESecurityGroup, self).__init__(delegate.provider, tag)
         self._description = description
         self._description = description
         self._delegate = delegate
         self._delegate = delegate
+        self._network = network
+        if self._network is None:
+            self._network = GCEFirewallsDelegate.DEFAULT_NETWORK 
 
 
     @property
     @property
     def id(self):
     def id(self):
-        return GCEFirewallsDelegate.tag_id(self._security_group)
+        return GCEFirewallsDelegate.tagnet_id(self._security_group,
+                                              self.network)
 
 
     @property
     @property
     def name(self):
     def name(self):
@@ -365,7 +394,8 @@ class GCESecurityGroup(BaseSecurityGroup):
     def description(self):
     def description(self):
         if self._description is not None:
         if self._description is not None:
             return self._description
             return self._description
-        for firewall in self._delegate.iter_firewalls(self._security_group):
+        for firewall in self._delegate.iter_firewalls(self._security_group,
+                                                      self.network):
             if 'description' in firewall:
             if 'description' in firewall:
                 return firewall['description']
                 return firewall['description']
         return None
         return None
@@ -373,10 +403,15 @@ class GCESecurityGroup(BaseSecurityGroup):
     @property
     @property
     def rules(self):
     def rules(self):
         out = []
         out = []
-        for firewall in self._delegate.iter_firewalls(self._security_group):
+        for firewall in self._delegate.iter_firewalls(self._security_group,
+                                                      self.network):
             out.append(GCESecurityGroupRule(self._delegate, firewall['id']))
             out.append(GCESecurityGroupRule(self._delegate, firewall['id']))
         return out
         return out
 
 
+    @property
+    def network(self):
+        return self._network
+
     @staticmethod
     @staticmethod
     def to_port_range(from_port, to_port):
     def to_port_range(from_port, to_port):
         if from_port is not None and to_port is not None:
         if from_port is not None and to_port is not None:
@@ -391,7 +426,8 @@ class GCESecurityGroup(BaseSecurityGroup):
         port = GCESecurityGroup.to_port_range(from_port, to_port)
         port = GCESecurityGroup.to_port_range(from_port, to_port)
         src_tag = src_group.name if src_group is not None else None
         src_tag = src_group.name if src_group is not None else None
         self._delegate.add_firewall(self._security_group, ip_protocol, port,
         self._delegate.add_firewall(self._security_group, ip_protocol, port,
-                                    cidr_ip, src_tag, self.description)
+                                    cidr_ip, src_tag, self.description,
+                                    self.network)
         return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
         return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
                              src_group)
                              src_group)
 
 
@@ -400,7 +436,8 @@ class GCESecurityGroup(BaseSecurityGroup):
         port = GCESecurityGroup.to_port_range(from_port, to_port)
         port = GCESecurityGroup.to_port_range(from_port, to_port)
         src_tag = src_group.name if src_group is not None else None
         src_tag = src_group.name if src_group is not None else None
         firewall_id = self._delegate.find_firewall(
         firewall_id = self._delegate.find_firewall(
-                self._security_group, ip_protocol, port, cidr_ip, src_tag)
+                self._security_group, ip_protocol, port, cidr_ip, src_tag,
+                self.network)
         if firewall_id is None:
         if firewall_id is None:
             return None
             return None
         return GCESecurityGroupRule(self._delegate, firewall_id)
         return GCESecurityGroupRule(self._delegate, firewall_id)
@@ -427,9 +464,10 @@ class GCESecurityGroupRule(BaseSecurityGroupRule):
     @property
     @property
     def parent(self):
     def parent(self):
         info = self._delegate.get_firewall_info(self._rule)
         info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'target_tag' not in info:
+        if info is None or 'target_tag' not in info or info['network'] is None:
             return None
             return None
-        return GCESecurityGroup(self._delegate, info['target_tag'])
+        return GCESecurityGroup(self._delegate, info['target_tag'],
+                                info['network'])
 
 
     @property
     @property
     def id(self):
     def id(self):
@@ -482,9 +520,10 @@ class GCESecurityGroupRule(BaseSecurityGroupRule):
     @property
     @property
     def group(self):
     def group(self):
         info = self._delegate.get_firewall_info(self._rule)
         info = self._delegate.get_firewall_info(self._rule)
-        if info is None or 'source_tag' not in info:
+        if info is None or 'source_tag' not in info or info['network'] is None:
             return None
             return None
-        return GCESecurityGroup(self._delegate, info['source_tag'])
+        return GCESecurityGroup(self._delegate, info['source_tag'],
+                                info['network'])
 
 
     def to_json(self):
     def to_json(self):
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
         attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))

+ 14 - 10
cloudbridge/cloud/providers/gce/services.py

@@ -196,17 +196,19 @@ class GCESecurityGroupService(BaseSecurityGroupService):
         self._delegate = GCEFirewallsDelegate(provider)
         self._delegate = GCEFirewallsDelegate(provider)
 
 
     def get(self, group_id):
     def get(self, group_id):
-        tag = self._delegate.get_tag_from_id(group_id)
-        return None if tag is None else GCESecurityGroup(self._delegate, tag)
+        tag, network = self._delegate.get_tagnet_from_id(group_id)
+        if tag is None:
+            return None
+        return GCESecurityGroup(self._delegate, tag, network)
 
 
     def list(self, limit=None, marker=None):
     def list(self, limit=None, marker=None):
-        security_groups = [GCESecurityGroup(self._delegate, x)
-                           for x in self._delegate.tags]
+        security_groups = [GCESecurityGroup(self._delegate, x, y)
+                           for x, y in self._delegate.tagnets]
         return ClientPagedResultList(self.provider, security_groups,
         return ClientPagedResultList(self.provider, security_groups,
                                      limit=limit, marker=marker)
                                      limit=limit, marker=marker)
 
 
-    def create(self, name, description):
-        return GCESecurityGroup(self._delegate, name, description)
+    def create(self, name, description, network_id=None):
+        return GCESecurityGroup(self._delegate, name, network_id, description)
 
 
     def find(self, name, limit=None, marker=None):
     def find(self, name, limit=None, marker=None):
         """
         """
@@ -214,12 +216,14 @@ class GCESecurityGroupService(BaseSecurityGroupService):
         name does not exist, or if it does not contain any rules, an empty list
         name does not exist, or if it does not contain any rules, an empty list
         is returned.
         is returned.
         """
         """
-        if self._delegate.has_tag(name):
-            return [GCESecurityGroup(self._delegate, name)]
-        return []
+        out = []
+        for tag, network in self._delegate.tagnets:
+            if tag == name:
+                out.append(GCESecurityGroup(self._delegate, name, network))
+        return out
 
 
     def delete(self, group_id):
     def delete(self, group_id):
-        return self._delegate.delete_tag_with_id(group_id)
+        return self._delegate.delete_tagnet_with_id(group_id)
 
 
 
 
 class GCEInstanceTypesService(BaseInstanceTypesService):
 class GCEInstanceTypesService(BaseInstanceTypesService):

+ 2 - 2
test/test_security_service.py

@@ -163,7 +163,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                                      to_port=1111, cidr_ip='0.0.0.0/0')
                                      to_port=1111, cidr_ip='0.0.0.0/0')
             self.assertTrue(
             self.assertTrue(
                 rule == found_rule,
                 rule == found_rule,
-                "Expected rule {0} not found in security group: {0}".format(
+                "Expected rule {0} not found in security group: {1}".format(
                     rule, sg.rules))
                     rule, sg.rules))
 
 
             object_keys = (
             object_keys = (
@@ -212,7 +212,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                                     to_port=1111, cidr_ip='0.0.0.0/0')
                                     to_port=1111, cidr_ip='0.0.0.0/0')
             self.assertTrue(
             self.assertTrue(
                 rule == same_rule,
                 rule == same_rule,
-                "Expected rule {0} not found in security group: {0}".format(
+                "Expected rule {0} not found in security group: {1}".format(
                     same_rule, sg.rules))
                     same_rule, sg.rules))
 
 
     def test_security_group_group_rule(self):
     def test_security_group_group_rule(self):