Просмотр исходного кода

Merge pull request #218 from gabriel-samfira/fix-memory-leak-in-rpc

Fix memory leak in RPC clients
Gabriel 4 лет назад
Родитель
Сommit
e87dbc3fdc
4 измененных файлов с 47 добавлено и 33 удалено
  1. 1 1
      coriolis/cmd/worker.py
  2. 6 4
      coriolis/providers/backup_writers.py
  3. 33 26
      coriolis/rpc.py
  4. 7 2
      coriolis/service.py

+ 1 - 1
coriolis/cmd/worker.py

@@ -31,7 +31,7 @@ def main():
     server = service.MessagingService(
         constants.WORKER_MAIN_MESSAGING_TOPIC,
         [rpc_server.WorkerServerEndpoint()],
-        rpc_server.VERSION, worker_count=worker_count)
+        rpc_server.VERSION, worker_count=worker_count, init_rpc=False)
     launcher = service.service.launch(
         CONF, server, workers=server.get_workers_count())
     launcher.wait()

+ 6 - 4
coriolis/providers/backup_writers.py

@@ -324,7 +324,7 @@ class SSHBackupWriterImpl(BaseBackupWriterImpl):
             data = self._sender_q.get()
             try:
                 self._send_msg(data)
-            except Exception as err:
+            except BaseException as err:
                 self._exception = err
                 raise
             finally:
@@ -340,7 +340,7 @@ class SSHBackupWriterImpl(BaseBackupWriterImpl):
                     payload["offset"],
                     payload["msg_id"])
                 self._sender_q.put(data)
-            except Exception as err:
+            except BaseException as err:
                 self._exception = err
                 raise
             finally:
@@ -634,9 +634,10 @@ class HTTPBackupWriterImpl(BaseBackupWriterImpl):
                         chunk, constants.COMPRESSION_FORMAT_GZIP)
                     if compressed:
                         send_payload["encoding"] = 'gzip'
-                except Exception as err:
+                except BaseException as err:
                     LOG.exception(err)
                     self._exception = err
+                    self._comp_q.task_done()
                     raise
             send_payload["chunk"] = chunk
             self._sender_q.put(send_payload)
@@ -675,11 +676,12 @@ class HTTPBackupWriterImpl(BaseBackupWriterImpl):
                     raise
             try:
                 send()
-            except Exception as err:
+            except BaseException as err:
                 # record the exception. We need to terminate
                 # the writer if this is set
                 LOG.exception(err)
                 self._exception = err
+                self._sender_q.task_done()
                 raise
             finally:
                 del headers

+ 33 - 26
coriolis/rpc.py

@@ -28,6 +28,7 @@ LOG = logging.getLogger(__name__)
 
 ALLOWED_EXMODS = [
     coriolis.exception.__name__]
+_TRANSPORT = None
 
 
 class RequestContextSerializer(messaging.Serializer):
@@ -65,6 +66,13 @@ def get_server(target, endpoints, serializer=None):
                                     serializer=serializer)
 
 
+def init():
+    global _TRANSPORT
+    if _TRANSPORT is None:
+        _TRANSPORT = _get_transport()
+    return _TRANSPORT
+
+
 class BaseRPCClient(object):
     """ Wrapper for 'oslo_messaging.RPCClient' which automatically
     instantiates and cleans up transports for each call.
@@ -76,43 +84,42 @@ class BaseRPCClient(object):
         if self._timeout is None:
             self._timeout = CONF.default_messaging_timeout
         self._serializer = RequestContextSerializer(serializer)
+        self._transport_conn = None
 
     def __repr__(self):
         return "<RPCClient(target=%s, timeout=%s)>" % (
             self._target, self._timeout)
 
-    @contextlib.contextmanager
-    def _rpc_messaging_client(self):
-        transport = None
-        try:
-            transport = _get_transport()
-            yield messaging.RPCClient(
-                transport, self._target, serializer=self._serializer,
+    @property
+    def _transport(self):
+        global _TRANSPORT
+        if _TRANSPORT is None:
+            if self._transport_conn is None:
+                self._transport_conn = _get_transport()
+            return self._transport_conn
+        else:
+            return _TRANSPORT
+
+    def _rpc_client(self):
+        return messaging.RPCClient(
+                self._transport, self._target,
+                serializer=self._serializer,
                 timeout=self._timeout)
-        finally:
-            if transport:
-                try:
-                    transport.cleanup()
-                except (Exception, KeyboardInterrupt):
-                    LOG.warn(
-                        "Exception occurred while cleaning up transport for "
-                        "RPC client instance '%s'. Error was: %s",
-                        repr(self), utils.get_exception_details())
 
     def _call(self, ctxt, method, **kwargs):
-        with self._rpc_messaging_client() as client:
-            return client.call(ctxt, method, **kwargs)
+        client = self._rpc_client()
+        return client.call(ctxt, method, **kwargs)
 
     def _call_on_host(self, host, ctxt, method, **kwargs):
-        with self._rpc_messaging_client() as client:
-            cctxt = client.prepare(server=host)
-            return cctxt.call(ctxt, method, **kwargs)
+        client = self._rpc_client()
+        cctxt = client.prepare(server=host)
+        return cctxt.call(ctxt, method, **kwargs)
 
     def _cast(self, ctxt, method, **kwargs):
-        with self._rpc_messaging_client() as client:
-            client.cast(ctxt, method, **kwargs)
+        client = self._rpc_client()
+        client.cast(ctxt, method, **kwargs)
 
     def _cast_for_host(self, host, ctxt, method, **kwargs):
-        with self._rpc_messaging_client() as client:
-            cctxt = client.prepare(server=host)
-            cctxt.cast(ctxt, method, **kwargs)
+        client = self._rpc_client()
+        cctxt = client.prepare(server=host)
+        cctxt.cast(ctxt, method, **kwargs)

+ 7 - 2
coriolis/service.py

@@ -98,7 +98,9 @@ def check_locks_dir_empty():
 
 
 class WSGIService(service.ServiceBase):
-    def __init__(self, name, worker_count=None):
+    def __init__(self, name, worker_count=None, init_rpc=True):
+        if init_rpc:
+            rpc.init()
         self._host = CONF.api_migration_listen
         self._port = CONF.api_migration_listen_port
 
@@ -137,7 +139,10 @@ class WSGIService(service.ServiceBase):
 
 
 class MessagingService(service.ServiceBase):
-    def __init__(self, topic, endpoints, version, worker_count=None):
+    def __init__(self, topic, endpoints, version,
+                 worker_count=None, init_rpc=True):
+        if init_rpc:
+            rpc.init()
         target = messaging.Target(topic=topic,
                                   server=utils.get_hostname(),
                                   version=version)