Przeglądaj źródła

Tests update for Instance.create param change

Mock tests will fail now because of https://github.com/spulec/moto/issues/801
Enis Afgan 9 lat temu
rodzic
commit
8ea3475043

+ 40 - 18
cloudbridge/cloud/providers/aws/services.py

@@ -29,7 +29,7 @@ from cloudbridge.cloud.interfaces.resources import KeyPair
 from cloudbridge.cloud.interfaces.resources import MachineImage
 from cloudbridge.cloud.interfaces.resources import Network
 from cloudbridge.cloud.interfaces.resources import PlacementZone
-# from cloudbridge.cloud.interfaces.resources import SecurityGroup
+from cloudbridge.cloud.interfaces.resources import SecurityGroup
 from cloudbridge.cloud.interfaces.resources import Snapshot
 from cloudbridge.cloud.interfaces.resources import Volume
 
@@ -583,19 +583,48 @@ class AWSInstanceService(BaseInstanceService):
                 raise ValueError(exc)
             return sn_id, zone_id
 
+        def _get_security_groups(security_groups, vpc_id=None, obj=False):
+            """
+            Resolve exact security groups to use.
+
+            :type security_groups: A ``list`` of ``SecurityGroup`` objects or
+                                   a list of ``str`` names.
+            :param security_groups: A list of ``SecurityGroup`` objects or a
+                                    list of ``SecurityGroup`` names, which
+                                    should be resolved.
+
+            :type vpc_id: ``str``
+            :param vpc_id: ID of the network within which to launch.
+
+            :type obj: ``bool``
+            :param obj: If True, return provider-native security group objects.
+                        Otherwise, return the IDs.
+
+            :rtype: list
+            :return: provider-native security group objects or the IDs (see
+                    ``obj`` param).
+            """
+            if isinstance(security_groups, list) and \
+               isinstance(security_groups[0], SecurityGroup):
+                return [sg._security_group if obj else sg.id
+                        for sg in security_groups]
+            else:
+                flters = {'group_name': security_groups}
+                if vpc_id:
+                    flters['vpc_id'] = vpc_id
+                sgs = self.provider.ec2_conn.get_all_security_groups(
+                    filters=flters)
+                return list(set([sg if obj else sg.id for sg in sgs]))
+
         if zone_id and vpc_id and security_groups:
             exc = "No subnets found in zone {0} for network {1}.".format(
                 zone_id, vpc_id)
             flters = {'availabilityZone': zone_id, 'state': 'available',
                       'vpcId': vpc_id}
             sn_id, _ = _get_potential_subnets(flters, exc)
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'vpc_id': vpc_id, 'group_name': security_groups})
-            sg_ids = list(set([sg.id for sg in sgs]))
+            sg_ids = _get_security_groups(security_groups, vpc_id)
         elif vpc_id and security_groups:
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'vpc_id': vpc_id, 'group_name': security_groups})
-            sg_ids = list(set([sg.id for sg in sgs]))
+            sg_ids = _get_security_groups(security_groups, vpc_id)
             exc = "No subnets found in network {0}.".format(vpc_id)
             flters = {'state': 'available', 'vpcId': vpc_id}
             sn_id, zone_id = _get_potential_subnets(flters, exc)
@@ -607,8 +636,7 @@ class AWSInstanceService(BaseInstanceService):
             sn_id, _ = _get_potential_subnets(flters, exc)
             sg_ids = None
         elif zone_id and security_groups:
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'group_name': security_groups})
+            sgs = _get_security_groups(security_groups, obj=True)
             # Get VPCs the supplied SGs belong to
             vpc_ids = list(set([sg.vpc_id for sg in sgs if sg.vpc_id]))
             vpcs = []
@@ -616,12 +644,9 @@ class AWSInstanceService(BaseInstanceService):
                 vpcs = self.provider.vpc_conn.get_all_vpcs(vpc_ids=vpc_ids)
             exc = ("No default network found for zone {0} and security groups "
                    "{1}".format(zone_id, security_groups))
-            print ("vpcs: %s" % vpcs)
             default_vpc = _get_default_vpc(vpcs, exc)
             # Filter only the SGs within the default VPC
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'vpc_id': default_vpc, 'group_name': security_groups})
-            sg_ids = list(set([sg.id for sg in sgs]))
+            sg_ids = _get_security_groups(security_groups, default_vpc)
             flters = {'availabilityZone': zone_id, 'state': 'available',
                       'vpc_id': default_vpc}
             exc = "No subnets found in zone {0} for default network {1}." \
@@ -644,8 +669,7 @@ class AWSInstanceService(BaseInstanceService):
             sn_id, _ = _get_potential_subnets(flters, exc)
             sg_ids = None
         elif security_groups:
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'group_name': security_groups})
+            sgs = _get_security_groups(security_groups, obj=True)
             # Get VPCs the supplied SGs belong to
             vpc_ids = list(set([sg.vpc_id for sg in sgs if sg.vpc_id]))
             vpcs = []
@@ -655,9 +679,7 @@ class AWSInstanceService(BaseInstanceService):
                 security_groups)
             default_vpc = _get_default_vpc(vpcs, exc)
             # Filter only the SGs within the default VPC
-            sgs = self.provider.ec2_conn.get_all_security_groups(
-                filters={'vpc_id': default_vpc, 'group_name': security_groups})
-            sg_ids = list(set([sg.id for sg in sgs]))
+            sg_ids = _get_security_groups(security_groups, default_vpc)
             flters = {'state': 'available', 'vpcId': default_vpc}
             exc = "No subnets found in network {0}.".format(default_vpc)
             sn_id, zone_id = _get_potential_subnets(flters, exc)

+ 1 - 0
docs/topics/overview.rst

@@ -8,6 +8,7 @@ Introductions to all the key parts of CloudBridge you'll need to know:
     How to install CloudBridge <install.rst>
     Connection and authentication setup <setup.rst>
     Launching instances <launch.rst>
+    Networking <networking.rst>
     Object states and lifecycles <object_lifecycles.rst>
     Paging and iteration <paging_and_iteration.rst>
     Using block storage <block_storage.rst>

+ 4 - 4
test/helpers.py

@@ -52,9 +52,8 @@ def cleanup_action(cleanup_func):
 
 TEST_DATA_CONFIG = {
     "AWSCloudProvider": {
-        "image": os.environ.get('CB_IMAGE_AWS', 'ami-d85e75b0'),
-        "instance_type": os.environ.get('CB_INSTANCE_TYPE_AWS',
-                                        't1.micro'),
+        "image": os.environ.get('CB_IMAGE_AWS', 'ami-5ac2cd4d'),
+        "instance_type": os.environ.get('CB_INSTANCE_TYPE_AWS', 't2.micro'),
         "placement": os.environ.get('CB_PLACEMENT_AWS', 'us-east-1a'),
     },
     "OpenStackCloudProvider": {
@@ -94,12 +93,13 @@ def delete_test_network(network):
 
 
 def create_test_instance(
-        provider, instance_name, zone=None, launch_config=None,
+        provider, instance_name, network, zone=None, launch_config=None,
         key_pair=None, security_groups=None):
     return provider.compute.instances.create(
         instance_name,
         get_provider_test_data(provider, 'image'),
         get_provider_test_data(provider, 'instance_type'),
+        network=network,
         zone=zone,
         key_pair=key_pair,
         security_groups=security_groups,

+ 2 - 3
test/test_compute_service.py

@@ -292,9 +292,8 @@ class CloudComputeServiceTestCase(ProviderTestBase):
             self.provider,
             name,
             network=net,
-            zone=helpers.get_provider_test_data(
-                self.provider,
-                'placement'),
+            # We don't have a way to match the test net placement and this zone
+            # zone=helpers.get_provider_test_data(self.provider, 'placement'),
             launch_config=lc)
 
         def cleanup(instance, net):