Преглед изворни кода

Make mock providers play nice with non-mock providers.

nuwan_ag пре 10 година
родитељ
комит
7465dc3d11

+ 20 - 7
cloudbridge/providers/aws/impl.py

@@ -6,11 +6,12 @@ import os
 
 
 import boto
 import boto
 from boto.ec2.regioninfo import RegionInfo
 from boto.ec2.regioninfo import RegionInfo
-
-from cloudbridge.providers.base import BaseCloudProvider
 from moto.ec2 import mock_ec2
 from moto.ec2 import mock_ec2
 from moto.s3 import mock_s3
 from moto.s3 import mock_s3
 
 
+from cloudbridge.providers.base import BaseCloudProvider
+from test.helpers import TestMockHelperMixin
+
 from .services import AWSBlockStoreService
 from .services import AWSBlockStoreService
 from .services import AWSComputeService
 from .services import AWSComputeService
 from .services import AWSImageService
 from .services import AWSImageService
@@ -94,11 +95,23 @@ class AWSCloudProviderV1(BaseCloudProvider):
         return s3_conn
         return s3_conn
 
 
 
 
-class MockAWSCloudProvider(AWSCloudProviderV1):
+class MockAWSCloudProvider(AWSCloudProviderV1, TestMockHelperMixin):
 
 
     def __init__(self, config):
     def __init__(self, config):
-        ec2mock = mock_ec2()
-        ec2mock.start()
-        s3mock = mock_s3()
-        s3mock.start()
         super(MockAWSCloudProvider, self).__init__(config)
         super(MockAWSCloudProvider, self).__init__(config)
+
+    def setUpMock(self):
+        """
+        Let Moto take over all socket communications
+        """
+        self.ec2mock = mock_ec2()
+        self.ec2mock.start()
+        self.s3mock = mock_s3()
+        self.s3mock.start()
+
+    def tearDownMock(self):
+        """
+        Stop Moto intercepting all socket communications
+        """
+        self.s3mock.stop()
+        self.ec2mock.stop()

+ 42 - 7
test/helpers.py

@@ -14,12 +14,6 @@ def parse_bool(val):
         return False
         return False
 
 
 
 
-TEST_WAIT_INTERVAL = 0 if parse_bool(
-    os.environ.get(
-        "CB_USE_MOCK_DRIVERS",
-        True)) else 5
-
-
 @contextmanager
 @contextmanager
 def exception_action(cleanup_func):
 def exception_action(cleanup_func):
     """
     """
@@ -77,12 +71,42 @@ def create_test_instance(provider, instance_name):
         get_provider_test_data(provider, 'instance_type'))
         get_provider_test_data(provider, 'instance_type'))
 
 
 
 
+def get_provider_wait_interval(provider):
+    if isinstance(provider, TestMockHelperMixin):
+        return 0
+    else:
+        return 5
+
+
 def get_test_instance(provider, name):
 def get_test_instance(provider, name):
     instance = create_test_instance(provider, name)
     instance = create_test_instance(provider, name)
-    instance.wait_till_ready(interval=TEST_WAIT_INTERVAL)
+    instance.wait_till_ready(interval=get_provider_wait_interval(provider))
     return instance
     return instance
 
 
 
 
+class TestMockHelperMixin(object):
+    """
+    A helper class that providers mock drivers can use to be notified when a
+    test setup/teardown occurs. This is useful when activating libraries
+    like HTTPretty which take over socket communications.
+    """
+
+    def setUpMock(self):
+        """
+        Called before a test is started.
+        """
+        raise NotImplementedError(
+            'TestMockHelperMixin.setUpMock not implemented')
+
+    def tearDownMock(self):
+        """
+        Called before test teardown.
+        """
+        raise NotImplementedError(
+            'TestMockHelperMixin.tearDownMock not implemented by this'
+            ' provider')
+
+
 class ProviderTestBase(object):
 class ProviderTestBase(object):
 
 
     """
     """
@@ -96,6 +120,17 @@ class ProviderTestBase(object):
         unittest.TestCase.__init__(self, methodName=methodName)
         unittest.TestCase.__init__(self, methodName=methodName)
         self.provider = provider
         self.provider = provider
 
 
+    def setUp(self):
+        if isinstance(self.provider, TestMockHelperMixin):
+            self.provider.setUpMock()
+
+    def tearDown(self):
+        if isinstance(self.provider, TestMockHelperMixin):
+            self.provider.tearDownMock()
+
+    def get_test_wait_interval(self):
+        return get_provider_wait_interval(self.provider)
+
 
 
 class ProviderTestCaseGenerator():
 class ProviderTestCaseGenerator():
 
 

+ 9 - 8
test/test_provider_block_store_service.py

@@ -23,7 +23,7 @@ class ProviderBlockStoreServiceTestCase(ProviderTestBase):
             1,
             1,
             helpers.get_provider_test_data(self.provider, "placement"))
             helpers.get_provider_test_data(self.provider, "placement"))
         with helpers.exception_action(lambda: test_vol.delete()):
         with helpers.exception_action(lambda: test_vol.delete()):
-            test_vol.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+            test_vol.wait_till_ready(interval=self.get_test_wait_interval())
             volumes = self.provider.block_store.volumes.list_volumes()
             volumes = self.provider.block_store.volumes.list_volumes()
             found_volumes = [vol for vol in volumes if vol.name == name]
             found_volumes = [vol for vol in volumes if vol.name == name]
             self.assertTrue(
             self.assertTrue(
@@ -34,7 +34,7 @@ class ProviderBlockStoreServiceTestCase(ProviderTestBase):
             test_vol.wait_for(
             test_vol.wait_for(
                 [VolumeState.DELETED, VolumeState.UNKNOWN],
                 [VolumeState.DELETED, VolumeState.UNKNOWN],
                 terminal_states=[VolumeState.ERROR],
                 terminal_states=[VolumeState.ERROR],
-                interval=helpers.TEST_WAIT_INTERVAL)
+                interval=self.get_test_wait_interval())
             volumes = self.provider.block_store.volumes.list_volumes()
             volumes = self.provider.block_store.volumes.list_volumes()
             found_volumes = [vol for vol in volumes if vol.name == name]
             found_volumes = [vol for vol in volumes if vol.name == name]
             self.assertTrue(
             self.assertTrue(
@@ -55,17 +55,17 @@ class ProviderBlockStoreServiceTestCase(ProviderTestBase):
             test_vol = self.provider.block_store.volumes.create_volume(
             test_vol = self.provider.block_store.volumes.create_volume(
                 name, 1, test_instance.placement_zone)
                 name, 1, test_instance.placement_zone)
             with helpers.exception_action(lambda: test_vol.delete()):
             with helpers.exception_action(lambda: test_vol.delete()):
-                test_vol.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+                test_vol.wait_till_ready(interval=self.get_test_wait_interval())
                 test_vol.attach(test_instance, '/dev/sda2')
                 test_vol.attach(test_instance, '/dev/sda2')
                 test_vol.wait_for(
                 test_vol.wait_for(
                     [VolumeState.IN_USE],
                     [VolumeState.IN_USE],
                     terminal_states=[VolumeState.ERROR, VolumeState.DELETED],
                     terminal_states=[VolumeState.ERROR, VolumeState.DELETED],
-                    interval=helpers.TEST_WAIT_INTERVAL)
+                    interval=self.get_test_wait_interval())
                 test_vol.detach()
                 test_vol.detach()
                 test_vol.wait_for(
                 test_vol.wait_for(
                     [VolumeState.AVAILABLE],
                     [VolumeState.AVAILABLE],
                     terminal_states=[VolumeState.ERROR, VolumeState.DELETED],
                     terminal_states=[VolumeState.ERROR, VolumeState.DELETED],
-                    interval=helpers.TEST_WAIT_INTERVAL)
+                    interval=self.get_test_wait_interval())
                 test_vol.delete()
                 test_vol.delete()
 
 
     def test_crud_snapshot(self):
     def test_crud_snapshot(self):
@@ -80,7 +80,7 @@ class ProviderBlockStoreServiceTestCase(ProviderTestBase):
             1,
             1,
             helpers.get_provider_test_data(self.provider, "placement"))
             helpers.get_provider_test_data(self.provider, "placement"))
         with helpers.exception_action(lambda: test_vol.delete()):
         with helpers.exception_action(lambda: test_vol.delete()):
-            test_vol.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+            test_vol.wait_till_ready(interval=self.get_test_wait_interval())
             snap_name = "CBSnapshot-{0}".format(name)
             snap_name = "CBSnapshot-{0}".format(name)
             test_snap = test_vol.create_snapshot(name=snap_name,
             test_snap = test_vol.create_snapshot(name=snap_name,
                                                  description=snap_name)
                                                  description=snap_name)
@@ -90,10 +90,11 @@ class ProviderBlockStoreServiceTestCase(ProviderTestBase):
                 snap.wait_for(
                 snap.wait_for(
                     [SnapshotState.UNKNOWN],
                     [SnapshotState.UNKNOWN],
                     terminal_states=[SnapshotState.ERROR],
                     terminal_states=[SnapshotState.ERROR],
-                    interval=helpers.TEST_WAIT_INTERVAL)
+                    interval=self.get_test_wait_interval())
 
 
             with helpers.exception_action(lambda: cleanup_snap(test_snap)):
             with helpers.exception_action(lambda: cleanup_snap(test_snap)):
-                test_snap.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+                test_snap.wait_till_ready(
+                    interval=self.get_test_wait_interval())
                 snaps = self.provider.block_store.snapshots.list_snapshots()
                 snaps = self.provider.block_store.snapshots.list_snapshots()
                 found_snaps = [snap for snap in snaps
                 found_snaps = [snap for snap in snaps
                                if snap.name == snap_name]
                                if snap.name == snap_name]

+ 2 - 2
test/test_provider_compute_service.py

@@ -17,7 +17,7 @@ class ProviderComputeServiceTestCase(ProviderTestBase):
             uuid.uuid4())
             uuid.uuid4())
         inst = helpers.create_test_instance(self.provider, name)
         inst = helpers.create_test_instance(self.provider, name)
         with helpers.exception_action(lambda: inst.terminate()):
         with helpers.exception_action(lambda: inst.terminate()):
-            inst.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+            inst.wait_till_ready(interval=self.get_test_wait_interval())
             all_instances = self.provider.compute.list_instances()
             all_instances = self.provider.compute.list_instances()
             found_instances = [i for i in all_instances if i.name == name]
             found_instances = [i for i in all_instances if i.name == name]
             self.assertTrue(
             self.assertTrue(
@@ -28,7 +28,7 @@ class ProviderComputeServiceTestCase(ProviderTestBase):
             inst.wait_for(
             inst.wait_for(
                 [InstanceState.TERMINATED, InstanceState.UNKNOWN],
                 [InstanceState.TERMINATED, InstanceState.UNKNOWN],
                 terminal_states=[InstanceState.ERROR],
                 terminal_states=[InstanceState.ERROR],
-                interval=helpers.TEST_WAIT_INTERVAL)
+                interval=self.get_test_wait_interval())
             deleted_inst = self.provider.compute.get_instance(inst.instance_id)
             deleted_inst = self.provider.compute.get_instance(inst.instance_id)
             self.assertTrue(
             self.assertTrue(
                 deleted_inst is None or deleted_inst.state in (
                 deleted_inst is None or deleted_inst.state in (

+ 3 - 2
test/test_provider_image_service.py

@@ -24,7 +24,8 @@ class ProviderImageServiceTestCase(ProviderTestBase):
             name = "CBUnitTestListImg-{0}".format(uuid.uuid4())
             name = "CBUnitTestListImg-{0}".format(uuid.uuid4())
             test_image = test_instance.create_image(name)
             test_image = test_instance.create_image(name)
             with helpers.exception_action(lambda: test_image.delete()):
             with helpers.exception_action(lambda: test_image.delete()):
-                test_image.wait_till_ready(interval=helpers.TEST_WAIT_INTERVAL)
+                test_image.wait_till_ready(
+                    interval=self.get_test_wait_interval())
                 images = self.provider.images.list_images()
                 images = self.provider.images.list_images()
                 found_images = [image for image in images
                 found_images = [image for image in images
                                 if image.name == name]
                                 if image.name == name]
@@ -36,7 +37,7 @@ class ProviderImageServiceTestCase(ProviderTestBase):
                 test_image.wait_for(
                 test_image.wait_for(
                     [MachineImageState.UNKNOWN],
                     [MachineImageState.UNKNOWN],
                     terminal_states=[MachineImageState.ERROR],
                     terminal_states=[MachineImageState.ERROR],
-                    interval=helpers.TEST_WAIT_INTERVAL)
+                    interval=self.get_test_wait_interval())
             # TODO: Images take a long time to deregister on EC2. Needs
             # TODO: Images take a long time to deregister on EC2. Needs
             # investigation
             # investigation
 #                 images = self.provider.images.list_images()
 #                 images = self.provider.images.list_images()