Просмотр исходного кода

Merge pull request #19 from chiniforooshan/gce_sg

GCE Security Groups Based on Firewalls
Enis Afgan 9 лет назад
Родитель
Сommit
e470a49d6a

+ 5 - 1
cloudbridge/cloud/base/resources.py

@@ -526,7 +526,7 @@ class BaseSecurityGroupRule(SecurityGroupRule, BaseCloudResource):
     def __init__(self, provider, rule, parent):
         super(BaseSecurityGroupRule, self).__init__(provider)
         self._rule = rule
-        self.parent = parent
+        self._parent = parent
 
     def __repr__(self):
         return ("<CBSecurityGroupRule: IP: {0}; from: {1}; to: {2}; grp: {3}>"
@@ -553,6 +553,10 @@ class BaseSecurityGroupRule(SecurityGroupRule, BaseCloudResource):
                                              self.to_port, self.cidr_ip,
                                              self.group))
 
+    @property
+    def parent(self):
+        return self._parent
+
 
 class BasePlacementZone(PlacementZone, BaseCloudResource):
 

+ 37 - 0
cloudbridge/cloud/providers/gce/README.rst

@@ -0,0 +1,37 @@
+CloudBridge support for `Google Cloud Platform`_. Compute is provided by `Google
+Compute Engine`_ (GCE). Object storage is provided by `Google Cloud Storage`_
+(GCE).
+
+Security Groups
+~~~~~~~~~~~~~~~
+CloudBridge API lets you control incoming traffic to VM instances by creating
+security groups, adding rules to security groups, and then assigning instances
+to security groups.
+
+GCE does this a little bit differently. GCE lets you assign `tags`_ to VM
+instances. Tags, then, can be used for networking purposes. In particular, you
+can create `firewall rules`_ to control incoming traffic to instances having a
+specific tag. So, to add GCE support to CloudBridge, we simulate security groups
+by tags.
+
+To make this more clear, let us consider the example of adding a rule to a
+security group. When you add a security group rule from the CloudBridge API to
+a security group ``sg``, what really happens is that a firewall with one rule
+is created whose ``targetTags`` is ``[sg]``. This makes sure that the rule
+applies to all instances that have ``sg`` as a tag (in CloudBridge language
+instances belonging to the security group ``sg``).
+
+**Note**: This implementation does not take advantage of the full power of GCE
+firewall format and only creates firewalls with one rule and only can find or
+list firewalls with one rule. This should be OK as long as all firewalls are
+created through the CloudBridge API.
+
+**Note**: The current implementation adds firewalls to the ``default`` network.
+
+.. _`Google Cloud Platform`: https://cloud.google.com/
+.. _`Google Compute Engine`: https://cloud.google.com/compute/docs
+.. _`Google Cloud Storage`: https://cloud.google.com/storage/docs
+.. _`tags`: https://cloud.google.com/compute/docs/reference/latest/instances/
+   setTags
+.. _`firewall rules`: https://cloud.google.com/compute/docs/
+   networking#firewall_rules

+ 333 - 0
cloudbridge/cloud/providers/gce/resources.py

@@ -5,7 +5,18 @@ from cloudbridge.cloud.base.resources import BaseInstanceType
 from cloudbridge.cloud.base.resources import BaseKeyPair
 from cloudbridge.cloud.base.resources import BasePlacementZone
 from cloudbridge.cloud.base.resources import BaseRegion
+from cloudbridge.cloud.base.resources import BaseSecurityGroup
+from cloudbridge.cloud.base.resources import BaseSecurityGroupRule
 
+# Older versions of Python do not have a built-in set data-structure.
+try:
+    set
+except NameError:
+    from sets import Set as set
+
+import hashlib
+import inspect
+import json
 
 class GCEKeyPair(BaseKeyPair):
 
@@ -162,3 +173,325 @@ class GCERegion(BaseRegion):
                  if zone['region'] == self._gce_region['selfLink']]
         return [GCEPlacementZone(self._provider, zone['name'], self.name)
                 for zone in zones]
+
+
+class GCEFirewallsDelegate(object):
+  
+    def __init__(self, provider):
+        self._provider = provider
+        self._list_response = None
+
+    @staticmethod
+    def tag_id(tag):
+        md5 = hashlib.md5()
+        md5.update(tag.encode('ascii'))
+        return md5.hexdigest()
+
+    @property
+    def provider(self):
+        return self._provider
+
+    @property
+    def tags(self):
+        out = set()
+        for firewall in self.iter_firewalls():
+            out.add(firewall['targetTags'][0])
+        return out
+            
+    def get_tag_from_id(self, tag_id):
+        for tag in self.tags:
+            if GCEFirewallsDelegate.tag_id(tag) == tag_id:
+                return tag
+        return None
+
+    def has_tag(self, tag):
+        return tag in self.tags
+
+    def delete_tag_with_id(self, tag_id):
+        tag = self.get_tag_from_id(tag_id)
+        if tag is None:
+            return
+        for firewall in self.iter_firewalls(tag):
+            self._delete_firewall(firewall)
+        self._update_list_response()
+
+    def add_firewall(self, tag, ip_protocol, port, source_range, source_tag,
+                     description):
+        if self.find_firewall(tag, ip_protocol, port, source_range,
+                              source_tag) is not None:
+            return True
+        # Do not let the user accidentally open traffic from the world by not
+        # explicitly specifying the source.
+        if source_tag is None and source_range is None:
+            return False
+        firewall_number = 1
+        suffixes = []
+        for firewall in self.iter_firewalls(tag):
+            suffix = firewall['name'].split('-')[-1]
+            if suffix.isdigit():
+                suffixes.append(int(suffix))
+        for suffix in sorted(suffixes):
+            if firewall_number == suffix:
+                firewall_number += 1
+        firewall = {'name': '%s-rule-%d' % (tag, firewall_number),
+                    'allowed': [{'IPProtocol': str(ip_protocol)}],
+                    'targetTags': [tag]}
+        if description is not None:
+            firewall['description'] = description
+        if port is not None:
+            firewall['allowed'][0]['ports'] = [port]
+        if source_range is not None:
+            firewall['sourceRanges'] = [source_range]
+        if source_tag is not None:
+            firewall['sourceTags'] = [source_tag]
+        project_name = self._provider.project_name
+        try:
+            response = (self._provider.gce_compute
+                                      .firewalls()
+                                      .insert(project=project_name,
+                                              body=firewall)
+                                      .execute())
+            self._provider.wait_for_global_operation(response)
+            # TODO: process the response and handle errors.
+            return True
+        except:
+            return False
+        finally:
+            self._update_list_response()
+
+    def find_firewall(self, tag, ip_protocol, port, source_range, source_tag):
+        if source_range is None and source_tag is None:
+            source_range = '0.0.0.0/0'
+        for firewall in self.iter_firewalls(tag):
+            if firewall['allowed'][0]['IPProtocol'] != ip_protocol:
+                continue
+            if not self._check_list_in_dict(firewall['allowed'][0], 'ports',
+                                            port):
+                continue
+            if not self._check_list_in_dict(firewall, 'sourceRanges',
+                                            source_range):
+                continue
+            if not self._check_list_in_dict(firewall, 'sourceTags', source_tag):
+                continue
+            return firewall['id']
+        return None
+
+    def get_firewall_info(self, firewall_id):
+        info = {}
+        for firewall in self.iter_firewalls():
+            if firewall['id'] != firewall_id:
+                continue
+            if ('sourceRanges' in firewall and
+                len(firewall['sourceRanges']) == 1):
+                info['source_range'] = firewall['sourceRanges'][0]
+            if 'sourceTags' in firewall and len(firewall['sourceTags']) == 1:
+                info['source_tag'] = firewall['sourceTags'][0]
+            if 'targetTags' in firewall and len(firewall['targetTags']) == 1:
+                info['target_tag'] = firewall['targetTags'][0]
+            if 'IPProtocol' in firewall['allowed'][0]:
+                info['ip_protocol'] = firewall['allowed'][0]['IPProtocol']
+            if ('ports' in firewall['allowed'][0] and
+                len(firewall['allowed'][0]['ports']) == 1):
+                info['port'] = firewall['allowed'][0]['ports'][0]
+            return info
+        return info
+
+    def delete_firewall_id(self, firewall_id):
+        for firewall in self.iter_firewalls():
+            if firewall['id'] == firewall_id:
+                self._delete_firewall(firewall)
+        self._update_list_response()
+
+    def iter_firewalls(self, tag=None):
+        if self._list_response is None:
+            self._update_list_response()
+        if 'items' not in self._list_response:
+            return
+        for firewall in self._list_response['items']:
+            if 'targetTags' not in firewall or len(firewall['targetTags']) != 1:
+                continue
+            if 'allowed' not in firewall or len(firewall['allowed']) != 1:
+                continue
+            if tag is None or firewall['targetTags'][0] == tag:
+                yield firewall
+
+    def _delete_firewall(self, firewall):
+        project_name = self._provider.project_name
+        try:
+            response = (self._provider.gce_compute
+                                      .firewalls()
+                                      .delete(project=project_name,
+                                              firewall=firewall['name'])
+                                      .execute())
+            self._provider.wait_for_global_operation(response)
+            # TODO: process the response and handle errors.
+            return True
+        except:
+            return False
+
+    def _update_list_response(self):
+        self._list_response = (
+                self._provider.gce_compute
+                              .firewalls()
+                              .list(project=self._provider.project_name)
+                              .execute())
+
+    def _check_list_in_dict(self, dictionary, field_name, value):
+        if field_name not in dictionary:
+            return value is None
+        if (value is None or
+            len(dictionary[field_name]) != 1 or
+            dictionary[field_name][0] != value):
+            return False
+        return True
+
+
+class GCESecurityGroup(BaseSecurityGroup):
+
+    def __init__(self, delegate, tag, description=None):
+        super(GCESecurityGroup, self).__init__(delegate.provider, tag)
+        self._description = description
+        self._delegate = delegate
+
+    @property
+    def id(self):
+        return GCEFirewallsDelegate.tag_id(self._security_group)
+
+    @property
+    def name(self):
+        return self._security_group
+
+    @property
+    def description(self):
+        if self._description is not None:
+            return self._description
+        for firewall in self._delegate.iter_firewalls(self._security_group):
+            if 'description' in firewall:
+                return firewall['description']
+        return None
+
+    @property
+    def rules(self):
+        out = []
+        for firewall in self._delegate.iter_firewalls(self._security_group):
+            out.append(GCESecurityGroupRule(self._delegate, firewall['id']))
+        return out
+
+    @staticmethod
+    def to_port_range(from_port, to_port):
+        if from_port is not None and to_port is not None:
+            return '%d-%d' % (from_port, to_port)
+        elif from_port is not None:
+            return from_port
+        else:
+            return to_port
+
+    def add_rule(self, ip_protocol, from_port=None, to_port=None,
+                 cidr_ip=None, src_group=None):
+        port = GCESecurityGroup.to_port_range(from_port, to_port)
+        src_tag = src_group.name if src_group is not None else None
+        self._delegate.add_firewall(self._security_group, ip_protocol, port,
+                                    cidr_ip, src_tag, self.description)
+        return self.get_rule(ip_protocol, from_port, to_port, cidr_ip,
+                             src_group)
+
+    def get_rule(self, ip_protocol=None, from_port=None, to_port=None,
+                 cidr_ip=None, src_group=None):
+        port = GCESecurityGroup.to_port_range(from_port, to_port)
+        src_tag = src_group.name if src_group is not None else None
+        firewall_id = self._delegate.find_firewall(
+                self._security_group, ip_protocol, port, cidr_ip, src_tag)
+        if firewall_id is None:
+            return None
+        return GCESecurityGroupRule(self._delegate, firewall_id)
+
+    def to_json(self):
+        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
+        js = {k: v for(k, v) in attr if not k.startswith('_')}
+        json_rules = [r.to_json() for r in self.rules]
+        js['rules'] = [json.loads(r) for r in json_rules]
+        return json.dumps(js, sort_keys=True)
+
+    def delete(self):
+        for rule in self.rules:
+            rule.delete()
+
+
+class GCESecurityGroupRule(BaseSecurityGroupRule):
+
+    def __init__(self, delegate, firewall_id):
+        super(GCESecurityGroupRule, self).__init__(
+                delegate.provider, firewall_id, None)
+        self._delegate = delegate
+
+    @property
+    def parent(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'target_tag' not in info:
+            return None
+        return GCESecurityGroup(self._delegate, info['target_tag'])
+
+    @property
+    def id(self):
+        return self._rule
+
+    @property
+    def ip_protocol(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'ip_protocol' not in info:
+            return None
+        return info['ip_protocol']
+
+    @property
+    def from_port(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'port' not in info:
+            return 0
+        port = info['port']
+        if port.isdigit():
+            return int(port)
+        parts = port.split('-')
+        if len(parts) > 2 or len(parts) < 1:
+            return 0
+        if parts[0].isdigit():
+            return int(parts[0])
+        return 0
+
+    @property
+    def to_port(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'port' not in info:
+            return 0
+        port = info['port']
+        if port.isdigit():
+            return int(port)
+        parts = port.split('-')
+        if len(parts) > 2 or len(parts) < 1:
+            return 0
+        if parts[-1].isdigit():
+            return int(parts[-1])
+        return 0
+
+    @property
+    def cidr_ip(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'source_range' not in info:
+            return None
+        return info['source_range']
+
+    @property
+    def group(self):
+        info = self._delegate.get_firewall_info(self._rule)
+        if info is None or 'source_tag' not in info:
+            return None
+        return GCESecurityGroup(self._delegate, info['source_tag'])
+
+    def to_json(self):
+        attr = inspect.getmembers(self, lambda a: not(inspect.isroutine(a)))
+        js = {k: v for(k, v) in attr if not k.startswith('_')}
+        js['group'] = self.group.id if self.group else ''
+        js['parent'] = self.parent.id if self.parent else ''
+        return json.dumps(js, sort_keys=True)
+
+    def delete(self):
+        self._delegate.delete_firewall_id(self._rule)

+ 32 - 2
cloudbridge/cloud/providers/gce/services.py

@@ -15,6 +15,9 @@ from retrying import retry
 from .resources import GCEInstanceType
 from .resources import GCEKeyPair
 from .resources import GCERegion
+from .resources import GCEFirewallsDelegate
+from .resources import GCESecurityGroup
+from .resources import GCESecurityGroupRule
 
 
 class GCESecurityService(BaseSecurityService):
@@ -24,6 +27,7 @@ class GCESecurityService(BaseSecurityService):
 
         # Initialize provider services
         self._key_pairs = GCEKeyPairService(provider)
+        self._security_groups = GCESecurityGroupService(provider)
 
     @property
     def key_pairs(self):
@@ -31,8 +35,7 @@ class GCESecurityService(BaseSecurityService):
 
     @property
     def security_groups(self):
-        raise NotImplementedError(
-            "GCECloudProvider does not implement this service")
+        return self._security_groups
 
 
 class GCEKeyPairService(BaseKeyPairService):
@@ -190,6 +193,33 @@ class GCESecurityGroupService(BaseSecurityGroupService):
 
     def __init__(self, provider):
         super(GCESecurityGroupService, self).__init__(provider)
+        self._delegate = GCEFirewallsDelegate(provider)
+
+    def get(self, group_id):
+        tag = self._delegate.get_tag_from_id(group_id)
+        return None if tag is None else GCESecurityGroup(self._delegate, tag)
+
+    def list(self, limit=None, marker=None):
+        security_groups = [GCESecurityGroup(self._delegate, x)
+                           for x in self._delegate.tags]
+        return ClientPagedResultList(self.provider, security_groups,
+                                     limit=limit, marker=marker)
+
+    def create(self, name, description):
+        return GCESecurityGroup(self._delegate, name, description)
+
+    def find(self, name, limit=None, marker=None):
+        """
+        Finds a non-empty security group. If a security group with the given
+        name does not exist, or if it does not contain any rules, an empty list
+        is returned.
+        """
+        if self._delegate.has_tag(name):
+            return [GCESecurityGroup(self._delegate, name)]
+        return []
+
+    def delete(self, group_id):
+        return self._delegate.delete_tag_with_id(group_id)
 
 
 class GCEInstanceTypesService(BaseInstanceTypesService):

+ 13 - 9
test/test_security_service.py

@@ -1,5 +1,6 @@
 import json
 from test.helpers import ProviderTestBase
+import time
 import uuid
 
 import test.helpers as helpers
@@ -12,7 +13,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             methodName=methodName, provider=provider)
 
     def test_crud_key_pair_service(self):
-        name = 'cbtestkeypairA-{0}'.format(uuid.uuid4())
+        name = 'cbtestkeypair-a'
         kp = self.provider.security.key_pairs.create(name=name)
         with helpers.cleanup_action(
             lambda:
@@ -67,7 +68,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             "Found a key pair {0} that should not exist?".format(no_kp))
 
     def test_key_pair(self):
-        name = 'cbtestkeypairB-{0}'.format(uuid.uuid4())
+        name = 'cbtestkeypair-b'
         kp = self.provider.security.key_pairs.create(name=name)
         with helpers.cleanup_action(lambda: kp.delete()):
             kpl = self.provider.security.key_pairs.list()
@@ -98,9 +99,11 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
             .format(name))
 
     def test_crud_security_group_service(self):
-        name = 'cbtestsecuritygroupA-{0}'.format(uuid.uuid4())
+        name = 'cbtestsecuritygroup-a'
         sg = self.provider.security.security_groups.create(
             name=name, description=name)
+        #Empty security groups don't exist in GCE. Let's add a dummy rule.
+        sg.add_rule(ip_protocol='tcp')
         with helpers.cleanup_action(
             lambda:
                 self.provider.security.security_groups.delete(group_id=sg.id)
@@ -154,7 +157,7 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
     def test_security_group(self):
         """Test for proper creation of a security group."""
-        name = 'cbtestsecuritygroupB-{0}'.format(uuid.uuid4())
+        name = 'cbtestsecuritygroup-b'
         sg = self.provider.security.security_groups.create(
             name=name, description=name)
         with helpers.cleanup_action(lambda: sg.delete()):
@@ -202,15 +205,16 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
 
     def test_security_group_group_role(self):
         """Test for proper creation of a security group rule."""
-        name = 'cbtestsecuritygroupC-{0}'.format(uuid.uuid4())
+        name = 'cbtestsecuritygroup-c'
         sg = self.provider.security.security_groups.create(
             name=name, description=name)
-        with helpers.cleanup_action(lambda: sg.delete()):
+        with helpers.cleanup_action(
+                lambda: None if sg is None else sg.delete()):
             self.assertTrue(
                 len(sg.rules) == 0,
                 "Expected no security group group rule. Got {0}."
                 .format(sg.rules))
-            rule = sg.add_rule(src_group=sg)
+            rule = sg.add_rule(ip_protocol='tcp', src_group=sg)
             self.assertTrue(
                 rule.group.name == name,
                 "Expected security group rule name {0}. Got {1}."
@@ -219,9 +223,9 @@ class CloudSecurityServiceTestCase(ProviderTestBase):
                 r.delete()
             sg = self.provider.security.security_groups.get(sg.id)  # update
             self.assertTrue(
-                len(sg.rules) == 0,
+                sg is None or len(sg.rules) == 0,
                 "Deleting SecurityGroupRule should delete it: {0}".format(
-                    sg.rules))
+                    [] if sg is None else sg.rules))
         sgl = self.provider.security.security_groups.list()
         found_sg = [g for g in sgl if g.name == name]
         self.assertTrue(