Sfoglia il codice sorgente

Simplified middleware handling through decorators

Nuwan Goonasekera 7 anni fa
parent
commit
c09c55d597
2 ha cambiato i file con 189 aggiunte e 50 eliminazioni
  1. 123 17
      cloudbridge/cloud/base/middleware.py
  2. 66 33
      test/test_middleware_system.py

+ 123 - 17
cloudbridge/cloud/base/middleware.py

@@ -1,48 +1,154 @@
+import inspect
 import logging
-from abc import abstractmethod
+import sys
 
+import six
 
+from ..base.events import InterceptingEventHandler
+from ..base.events import ObservingEventHandler
+from ..interfaces.exceptions import CloudBridgeBaseException
 from ..interfaces.middleware import Middleware
 from ..interfaces.middleware import MiddlewareManager
 
 log = logging.getLogger(__name__)
 
 
+def intercept(event_pattern, priority):
+    def deco(f):
+        # Mark function as having an event_handler so we can discover it
+        # The callback cannot be set to f as it is not bound yet and will be
+        # set during auto discovery
+        f.__event_handler = InterceptingEventHandler(
+            event_pattern, priority, None)
+        return f
+    return deco
+
+
+def observe(event_pattern, priority):
+    def deco(f):
+        # Mark function as having an event_handler so we can discover it
+        # The callback cannot be set to f as it is not bound yet and will be
+        # set during auto discovery
+        f.__event_handler = ObservingEventHandler(
+            event_pattern, priority, None)
+        return f
+    return deco
+
+
 class SimpleMiddlewareManager(MiddlewareManager):
 
     def __init__(self, event_manager):
         self.events = event_manager
-        self.middleware = []
+        self.middleware_list = []
 
     def add(self, middleware):
-        self.middleware.append(middleware)
-        middleware.install(self.events)
+        if isinstance(middleware, Middleware):
+            m = middleware
+        else:
+            m = AutoDiscoveredMiddleware(middleware)
+        m.install(self.events)
+        self.middleware_list.append(m)
+        return m
 
     def remove(self, middleware):
         middleware.uninstall()
-        self.middleware.remove(middleware)
+        self.middleware_list.remove(middleware)
 
 
 class BaseMiddleware(Middleware):
 
-    def install(self, event_manager):
+    def __init__(self):
         self.event_handlers = []
-        self.events = event_manager
-        self.setup()
 
-    @abstractmethod
-    def setup(self):
-        pass
+    def install(self, event_manager):
+        self.events = event_manager
+        discovered_handlers = self.discover_handlers(self)
+        self.add_handlers(discovered_handlers)
 
-    def add_observer(self, event_pattern, priority, callback):
-        handler = self.events.observe(event_pattern, priority, callback)
-        self.event_handlers.append(handler)
+    def add_handlers(self, handlers):
+        if not hasattr(self, "event_handlers"):
+            # In case the user forgot to call super class init
+            self.event_handlers = []
+        for handler in handlers:
+            self.events.subscribe(handler)
+        self.event_handlers.extend(handlers)
 
-    def add_interceptor(self, event_pattern, priority, callback):
-        handler = self.events.intercept(event_pattern, priority, callback)
-        self.event_handlers.append(handler)
+    def discover_handlers(self, class_or_obj):
+        discovered_handlers = []
+        for _, func in inspect.getmembers(class_or_obj, inspect.ismethod):
+            handler = getattr(func, "__event_handler", None)
+            if handler:
+                # Set the properly bound method as the callback
+                handler.callback = func
+                discovered_handlers.append(handler)
+        return discovered_handlers
 
     def uninstall(self):
         for handler in self.event_handlers:
             handler.unsubscribe()
+        self.event_handlers = []
         self.events = None
+
+
+class AutoDiscoveredMiddleware(BaseMiddleware):
+
+    def __init__(self, class_or_obj):
+        super(AutoDiscoveredMiddleware, self).__init__()
+        self.obj_to_discover = class_or_obj
+
+    def install(self, event_manager):
+        super(AutoDiscoveredMiddleware, self).install(event_manager)
+        discovered_handlers = self.discover_handlers(self.obj_to_discover)
+        self.add_handlers(discovered_handlers)
+
+
+class EventDebugLoggingMiddleware(BaseMiddleware):
+    """
+    Logs all event parameters. This middleware should not be enabled other
+    than for debugging, as it could log sensitive parameters such as
+    access keys.
+    """
+    def setup(self):
+        self.add_observer(
+            event_pattern="*", priority=1100, callback=self.pre_log_event)
+        self.add_interceptor(
+            event_pattern="*", priority=1150, callback=self.post_log_event)
+
+    @observe(event_pattern="*", priority=1100)
+    def pre_log_event(self, **kwargs):
+        log.debug("Event: {0} invoked with args: {1}".format(
+            kwargs.get("event"), kwargs))
+
+    @intercept(event_pattern="*", priority=1150)
+    def post_log_event(self, **kwargs):
+        next_handler = kwargs.pop("next_handler")
+        result = next_handler.invoke(**kwargs)
+        log.debug("Event: {0} result: {1}".format(
+            kwargs.get("event"), result))
+        return result
+
+
+class ExceptionWrappingMiddleware(BaseMiddleware):
+    """
+    Wraps all unhandled exceptions in cloudbridge exceptions.
+    """
+    def setup(self):
+        self.add_interceptor(
+            event_pattern="*", priority=1050, callback=self.wrap_exception)
+
+    def wrap_exception(self, **kwargs):
+        next_handler = kwargs.pop("next_handler")
+        try:
+            return next_handler.invoke(**kwargs)
+        except Exception as e:
+            if isinstance(e, CloudBridgeBaseException):
+                raise
+            else:
+                ex_type, ex_value, traceback = sys.exc_info()
+                cb_ex = CloudBridgeBaseException(
+                    "CloudBridgeBaseException: {0} from exception type: {1}"
+                    .format(ex_value, ex_type))
+                if sys.version_info >= (3, 0):
+                    six.raise_from(cb_ex, e)
+                else:
+                    six.reraise(CloudBridgeBaseException, cb_ex, traceback)

+ 66 - 33
test/test_middleware_system.py

@@ -3,6 +3,8 @@ import unittest
 from cloudbridge.cloud.base.events import SimpleEventDispatcher
 from cloudbridge.cloud.base.middleware import BaseMiddleware
 from cloudbridge.cloud.base.middleware import SimpleMiddlewareManager
+from cloudbridge.cloud.base.middleware import intercept
+from cloudbridge.cloud.base.middleware import observe
 from cloudbridge.cloud.interfaces.middleware import Middleware
 
 
@@ -42,15 +44,11 @@ class MiddlewareSystemTestCase(unittest.TestCase):
             def __init__(self):
                 self.invocation_order = ""
 
-            def setup(self):
-                self.add_observer(event_pattern="some.event.*", priority=1000,
-                                  callback=self.my_callback_obs)
-                self.add_interceptor(event_pattern="some.*", priority=900,
-                                     callback=self.my_callback_intcpt)
-
+            @observe(event_pattern="some.event.*", priority=1000)
             def my_callback_obs(self, **kwargs):
                 self.invocation_order += "observe"
 
+            @intercept(event_pattern="some.event.*", priority=900)
             def my_callback_intcpt(self, **kwargs):
                 self.invocation_order += "intercept_"
                 return kwargs.get('next_handler').invoke(**kwargs)
@@ -76,38 +74,22 @@ class MiddlewareSystemTestCase(unittest.TestCase):
 
         class DummyMiddleWare1(BaseMiddleware):
 
-            def __init__(self):
-                self.invocation_order = ""
-
-            def setup(self):
-                self.add_observer(event_pattern="some.really.*", priority=1000,
-                                  callback=self.my_callback_obs1)
-                self.add_interceptor(event_pattern="some.*", priority=900,
-                                     callback=self.my_callback_intcpt2)
-
+            @observe(event_pattern="some.really.*", priority=1000)
             def my_callback_obs1(self, **kwargs):
-                self.invocation_order += "observe"
+                pass
 
+            @intercept(event_pattern="some.*", priority=900)
             def my_callback_intcpt2(self, **kwargs):
-                self.invocation_order += "intercept_"
                 return kwargs.get('next_handler').invoke(**kwargs)
 
         class DummyMiddleWare2(BaseMiddleware):
 
-            def __init__(self):
-                self.invocation_order = ""
-
-            def setup(self):
-                self.add_observer(event_pattern="some.really.*", priority=1050,
-                                  callback=self.my_callback_obs1)
-                self.add_interceptor(event_pattern="*", priority=950,
-                                     callback=self.my_callback_intcpt2)
-
-            def my_callback_obs1(self, **kwargs):
-                self.invocation_order += "observe"
+            @observe(event_pattern="some.really.*", priority=1050)
+            def my_callback_obs3(self, **kwargs):
+                pass
 
-            def my_callback_intcpt2(self, **kwargs):
-                self.invocation_order += "intercept_"
+            @intercept(event_pattern="*", priority=950)
+            def my_callback_intcpt4(self, **kwargs):
                 return kwargs.get('next_handler').invoke(**kwargs)
 
         dispatcher = SimpleEventDispatcher()
@@ -120,8 +102,8 @@ class MiddlewareSystemTestCase(unittest.TestCase):
 
         # Callbacks in both middleware classes should be registered
         self.assertListEqual(
-            [middleware1.my_callback_intcpt2, middleware2.my_callback_intcpt2,
-             middleware1.my_callback_obs1, middleware2.my_callback_obs1],
+            [middleware1.my_callback_intcpt2, middleware2.my_callback_intcpt4,
+             middleware1.my_callback_obs1, middleware2.my_callback_obs3],
             [handler.callback for handler
              in dispatcher.get_handlers_for_event(EVENT_NAME)])
 
@@ -129,6 +111,57 @@ class MiddlewareSystemTestCase(unittest.TestCase):
 
         # Only middleware2 callbacks should be registered
         self.assertListEqual(
-            [middleware2.my_callback_intcpt2, middleware2.my_callback_obs1],
+            [middleware2.my_callback_intcpt4, middleware2.my_callback_obs3],
+            [handler.callback for handler in
+             dispatcher.get_handlers_for_event(EVENT_NAME)])
+
+        # add middleware back to check that internal state is properly handled
+        manager.add(middleware1)
+
+        # should one again equal original list
+        self.assertListEqual(
+            [middleware1.my_callback_intcpt2, middleware2.my_callback_intcpt4,
+             middleware1.my_callback_obs1, middleware2.my_callback_obs3],
+            [handler.callback for handler
+             in dispatcher.get_handlers_for_event(EVENT_NAME)])
+
+    def test_automatic_middleware(self):
+        EVENT_NAME = "another.interesting.event.occurred"
+
+        class SomeDummyClass(object):
+
+            @observe(event_pattern="another.really.*", priority=1000)
+            def not_a_match(self, **kwargs):
+                pass
+
+            @intercept(event_pattern="another.*", priority=900)
+            def my_callback_intcpt2(self, **kwargs):
+                pass
+
+            def not_an_event_handler(self, **kwargs):
+                pass
+
+            @observe(event_pattern="another.interesting.*", priority=1000)
+            def my_callback_obs1(self, **kwargs):
+                pass
+
+        dispatcher = SimpleEventDispatcher()
+        manager = SimpleMiddlewareManager(dispatcher)
+        some_obj = SomeDummyClass()
+        middleware = manager.add(some_obj)
+        dispatcher.emit(self, EVENT_NAME)
+
+        # Middleware should be discovered even if class containing interceptors
+        # doesn't inherit from Middleware
+        self.assertListEqual(
+            [some_obj.my_callback_intcpt2, some_obj.my_callback_obs1],
+            [handler.callback for handler
+             in dispatcher.get_handlers_for_event(EVENT_NAME)])
+
+        manager.remove(middleware)
+
+        # Callbacks should be correctly removed
+        self.assertListEqual(
+            [],
             [handler.callback for handler in
              dispatcher.get_handlers_for_event(EVENT_NAME)])