Przeglądaj źródła

Middleware support in factory with constructor option

almahmoud 7 lat temu
rodzic
commit
614ca04850

+ 2 - 6
cloudbridge/cloud/base/provider.py

@@ -8,8 +8,6 @@ try:
 except ImportError:  # Python 2
     from ConfigParser import SafeConfigParser as ConfigParser
 
-from pyeventsystem.middleware import SimpleMiddlewareManager
-
 import six
 
 from ..base.middleware import ExceptionWrappingMiddleware
@@ -83,13 +81,11 @@ class BaseConfiguration(Configuration):
 
 
 class BaseCloudProvider(CloudProvider):
-    def __init__(self, config, middleware_list=[]):
+    def __init__(self, config, middleware_manager=None):
         self._config = BaseConfiguration(config)
         self._config_parser = ConfigParser()
         self._config_parser.read(CloudBridgeConfigLocations)
-        self._middleware = SimpleMiddlewareManager()
-        for each_middleware in middleware_list:
-            self._middleware.add(each_middleware)
+        self._middleware = middleware_manager.generate_simple_manager()
         self.add_required_middleware()
 
     @property

+ 28 - 6
cloudbridge/cloud/factory.py

@@ -4,6 +4,8 @@ import logging
 import pkgutil
 from collections import defaultdict
 
+from pyeventsystem.middleware import SimpleMiddlewareManager
+
 from cloudbridge.cloud import providers
 from cloudbridge.cloud.interfaces import CloudProvider
 from cloudbridge.cloud.interfaces import TestMockHelperMixin
@@ -12,6 +14,29 @@ from cloudbridge.cloud.interfaces import TestMockHelperMixin
 log = logging.getLogger(__name__)
 
 
+# Todo: Move to pyeventsystem if we're keeping this logic
+class ParentMiddlewareManager(SimpleMiddlewareManager):
+
+    def __init__(self, event_manager=None):
+        super(ParentMiddlewareManager, self).__init__(event_manager)
+        self.middleware_constructors = []
+
+    def add_constructor(self, middleware_class, *args):
+        self.middleware_constructors.append((middleware_class, args))
+
+    def remove_constructor(self, middleware_class, *args):
+        self.middleware_constructors.remove((middleware_class, args))
+
+    def generate_simple_manager(self):
+        new_manager = SimpleMiddlewareManager()
+        for middleware in self.middleware_list:
+            new_manager.add(middleware)
+        for constructor, args in self.middleware_constructors:
+            m = constructor(*args)
+            new_manager.add(m)
+        return new_manager
+
+
 class ProviderList(object):
     AWS = 'aws'
     AZURE = 'azure'
@@ -27,17 +52,14 @@ class CloudProviderFactory(object):
     """
 
     def __init__(self):
-        self._middleware = []
+        self._middleware = ParentMiddlewareManager()
         self.provider_list = defaultdict(dict)
         log.debug("Providers List: %s", self.provider_list)
 
     @property
-    def added_middleware(self):
+    def middleware(self):
         return self._middleware
 
-    def add_middleware(self, middleware):
-        self._middleware.append(middleware)
-
     def register_provider_class(self, cls):
         """
         Registers a provider class with the factory. The class must
@@ -144,7 +166,7 @@ class CloudProviderFactory(object):
                 'A provider with name {0} could not be'
                 ' found'.format(name))
         log.debug("Created '%s' provider", name)
-        return provider_class(config, self.added_middleware)
+        return provider_class(config, self.middleware)
 
     def get_provider_class(self, name):
         """

+ 2 - 2
cloudbridge/cloud/providers/aws/provider.py

@@ -17,8 +17,8 @@ class AWSCloudProvider(BaseCloudProvider):
     PROVIDER_ID = 'aws'
     AWS_INSTANCE_DATA_DEFAULT_URL = "http://cloudve.org/cb-aws-vmtypes.json"
 
-    def __init__(self, config, middleware_list=[]):
-        super(AWSCloudProvider, self).__init__(config, middleware_list)
+    def __init__(self, config, middleware_manager=None):
+        super(AWSCloudProvider, self).__init__(config, middleware_manager)
 
         # Initialize cloud connection fields
         # These are passed as-is to Boto

+ 2 - 2
cloudbridge/cloud/providers/azure/provider.py

@@ -23,8 +23,8 @@ log = logging.getLogger(__name__)
 class AzureCloudProvider(BaseCloudProvider):
     PROVIDER_ID = 'azure'
 
-    def __init__(self, config, middleware_list=[]):
-        super(AzureCloudProvider, self).__init__(config, middleware_list)
+    def __init__(self, config, middleware_manager=[]):
+        super(AzureCloudProvider, self).__init__(config, middleware_manager)
 
         # mandatory config values
         self.subscription_id = self._get_config_value(

+ 2 - 2
cloudbridge/cloud/providers/gcp/provider.py

@@ -196,8 +196,8 @@ class GCPCloudProvider(BaseCloudProvider):
 
     PROVIDER_ID = 'gcp'
 
-    def __init__(self, config, middleware_list=[]):
-        super(GCPCloudProvider, self).__init__(config, middleware_list)
+    def __init__(self, config, middleware_manager=[]):
+        super(GCPCloudProvider, self).__init__(config, middleware_manager)
 
         # Disable warnings about file_cache not being available when using
         # oauth2client >= 4.0.0.

+ 2 - 2
cloudbridge/cloud/providers/mock/provider.py

@@ -25,9 +25,9 @@ class MockAWSCloudProvider(AWSCloudProvider, TestMockHelperMixin):
     """
     PROVIDER_ID = 'mock'
 
-    def __init__(self, config, middleware_list=[]):
+    def __init__(self, config, middleware_manager=[]):
         self.setUpMock()
-        super(MockAWSCloudProvider, self).__init__(config, middleware_list)
+        super(MockAWSCloudProvider, self).__init__(config, middleware_manager)
 
     def setUpMock(self):
         """

+ 2 - 2
cloudbridge/cloud/providers/openstack/provider.py

@@ -31,8 +31,8 @@ class OpenStackCloudProvider(BaseCloudProvider):
 
     PROVIDER_ID = 'openstack'
 
-    def __init__(self, config, middleware_list=[]):
-        super(OpenStackCloudProvider, self).__init__(config, middleware_list)
+    def __init__(self, config, middleware_manager=[]):
+        super(OpenStackCloudProvider, self).__init__(config, middleware_manager)
 
         # Initialize cloud connection fields
         self.username = self._get_config_value(

+ 65 - 7
test/test_cloud_factory.py

@@ -90,17 +90,75 @@ class CloudFactoryTestCase(unittest.TestCase):
                         factory.get_all_provider_classes())
 
     def test_middleware_inherited(self):
-        return_str = "hello world"
+        start_count = 10
+
         class SomeDummyClass(object):
+            count = start_count
 
             @intercept(event_pattern="*", priority=2499)
-            def return_hello_world(self, event_args, *args, **kwargs):
-                return return_str
+            def return_incremented(self, event_args, *args, **kwargs):
+                self.count += 1
+                return self.count
 
         factory = CloudProviderFactory()
         some_obj = SomeDummyClass()
-        factory.add_middleware(some_obj)
+        factory.middleware.add(some_obj)
+        provider_name = cb_helpers.get_env("CB_TEST_PROVIDER", "aws")
+        first_prov = factory.create_provider(provider_name, {})
+        # Any dispatched event should be intercepted and increment the count
+        first_prov.storage.volumes.get("anything")
+        self.assertEqual(first_prov.networking.networks.get("anything"),
+                         start_count + 2)
+        second_prov = factory.create_provider(provider_name, {})
+        # This count should be independent of the previous one
+        self.assertEqual(second_prov.networking.networks.get("anything"),
+                         start_count + 3)
+
+    def test_middleware_inherited_constructor(self):
+        start_count = 10
+        increment = 2
+
+        class SomeDummyClass(object):
+            count = start_count
+
+            @intercept(event_pattern="*", priority=2499)
+            def return_incremented(self, event_args, *args, **kwargs):
+                self.count += 1
+                return self.count
+
+        factory = CloudProviderFactory()
+        factory.middleware.add_constructor(SomeDummyClass)
+        provider_name = cb_helpers.get_env("CB_TEST_PROVIDER", "aws")
+        first_prov = factory.create_provider(provider_name, {})
+        # Any dispatched event should be intercepted and increment the count
+        first_prov.storage.volumes.get("anything")
+        self.assertEqual(first_prov.networking.networks.get("anything"),
+                         start_count + 2)
+        second_prov = factory.create_provider(provider_name, {})
+        # This count should be independent of the previous one
+        self.assertEqual(second_prov.networking.networks.get("anything"),
+                         start_count + 1)
+
+        class SomeDummyClassWithArgs(object):
+            def __init__(self, start, increment):
+                self.count = start
+                self.increment = increment
+
+            @intercept(event_pattern="*", priority=2499)
+            def return_incremented(self, event_args, *args, **kwargs):
+                self.count += self.increment
+                return self.count
+
+        factory = CloudProviderFactory()
+        factory.middleware.add_constructor(SomeDummyClassWithArgs,
+                                           start_count, increment)
         provider_name = cb_helpers.get_env("CB_TEST_PROVIDER", "aws")
-        prov = factory.create_provider(provider_name, {})
-        # Any dispatched event should be intercepted and return the string
-        self.assertEqual(prov.storage.volumes.get("anything"), return_str)
+        first_prov = factory.create_provider(provider_name, {})
+        # Any dispatched event should be intercepted and increment the count
+        first_prov.storage.volumes.get("anything")
+        self.assertEqual(first_prov.networking.networks.get("anything"),
+                         start_count + 2*increment)
+        second_prov = factory.create_provider(provider_name, {})
+        # This count should be independent of the previous one
+        self.assertEqual(second_prov.networking.networks.get("anything"),
+                         start_count + increment)