Alessandro Pilotti 10 лет назад
Родитель
Сommit
e7963ac576
1 измененных файлов с 44 добавлено и 54 удалено
  1. 44 54
      coriolis/worker/rpc/server.py

+ 44 - 54
coriolis/worker/rpc/server.py

@@ -46,12 +46,9 @@ class WorkerServerEndpoint(object):
         self._server = utils.get_hostname()
         self._server = utils.get_hostname()
         self._rpc_conductor_client = rpc_conductor_client.ConductorClient()
         self._rpc_conductor_client = rpc_conductor_client.ConductorClient()
 
 
-    def _get_task_export_path(self, task_id):
-        return os.path.join(CONF.worker.export_base_path, task_id)
-
     def _cleanup_task_resources(self, task_id, task_info=None):
     def _cleanup_task_resources(self, task_id, task_info=None):
         try:
         try:
-            export_path = self._get_task_export_path(task_id)
+            export_path = _get_task_export_path(task_id)
             if (not task_info or export_path not in
             if (not task_info or export_path not in
                     task_info.get(TMP_DIRS_KEY, [])):
                     task_info.get(TMP_DIRS_KEY, [])):
                 # Don't remove folder if it's needed by the dependent tasks
                 # Don't remove folder if it's needed by the dependent tasks
@@ -78,10 +75,14 @@ class WorkerServerEndpoint(object):
         except psutil.NoSuchProcess:
         except psutil.NoSuchProcess:
             LOG.info("Task process not found: %s", process_id)
             LOG.info("Task process not found: %s", process_id)
 
 
-    def _exec_task_process(self, ctxt, task_id, target, args):
+    def _exec_task_process(self, ctxt, task_id, task_type, origin, destination,
+                           instance, task_info):
         mp_ctx = multiprocessing.get_context('spawn')
         mp_ctx = multiprocessing.get_context('spawn')
         mp_q = mp_ctx.Queue()
         mp_q = mp_ctx.Queue()
-        p = mp_ctx.Process(target=target, args=(args + (ctxt, task_id, mp_q,)))
+        p = mp_ctx.Process(
+            target=_task_process,
+            args=(ctxt, task_id, task_type, origin, destination, instance,
+                  task_info, mp_q))
 
 
         p.start()
         p.start()
         LOG.info("Task process started: %s", task_id)
         LOG.info("Task process started: %s", task_id)
@@ -101,37 +102,14 @@ class WorkerServerEndpoint(object):
     def exec_task(self, ctxt, task_id, task_type, origin, destination,
     def exec_task(self, ctxt, task_id, task_type, origin, destination,
                   instance, task_info):
                   instance, task_info):
         try:
         try:
-            new_task_info = None
-
-            if task_type == constants.TASK_TYPE_EXPORT_INSTANCE:
-                provider = factory.get_provider(
-                    origin["type"], constants.PROVIDER_TYPE_EXPORT)
-                export_path = self._get_task_export_path(task_id)
-                if not os.path.exists(export_path):
-                    os.makedirs(export_path)
-
-                new_task_info = self._exec_task_process(
-                    ctxt, task_id, _export_instance,
-                    (provider, origin.get("connection_info", {}),
-                     instance, export_path))
-
-                new_task_info[TMP_DIRS_KEY] = [export_path]
-
-            elif task_type == constants.TASK_TYPE_IMPORT_INSTANCE:
-                provider = factory.get_provider(
-                    destination["type"], constants.PROVIDER_TYPE_IMPORT)
-
-                self._exec_task_process(
-                    ctxt, task_id, _import_instance,
-                    (provider, destination.get("connection_info", {}),
-                     destination["target_environment"],
-                     instance, task_info))
-            else:
-                raise exception.CoriolisException("Unknown task type: %s" %
-                                                  task_type)
+            new_task_info = self._exec_task_process(
+                ctxt, task_id, task_type, origin, destination,
+                instance, task_info)
+
+            if new_task_info:
+                LOG.info("Task info: %s", new_task_info)
 
 
             LOG.info("Task completed: %s", task_id)
             LOG.info("Task completed: %s", task_id)
-            LOG.info("Task info: %s", new_task_info)
             self._rpc_conductor_client.task_completed(ctxt, task_id,
             self._rpc_conductor_client.task_completed(ctxt, task_id,
                                                       new_task_info)
                                                       new_task_info)
 
 
@@ -145,35 +123,47 @@ class WorkerServerEndpoint(object):
             self._remove_tmp_dirs(task_info)
             self._remove_tmp_dirs(task_info)
 
 
 
 
-def _export_instance(provider, connection_info, instance, export_path,
-                     ctxt, task_id, mp_q):
+def _get_task_export_path(task_id, create=False):
+    export_path = os.path.join(CONF.worker.export_base_path, task_id)
+    if create and not os.path.exists(export_path):
+        os.makedirs(export_path)
+    return export_path
+
+
+def _task_process(ctxt, task_id, task_type, origin, destination, instance,
+                  task_info, mp_q):
     try:
     try:
         # Setting up logging, needed since this is a new process
         # Setting up logging, needed since this is a new process
         utils.setup_logging()
         utils.setup_logging()
 
 
+        if task_type == constants.TASK_TYPE_EXPORT_INSTANCE:
+            provider_type = constants.PROVIDER_TYPE_EXPORT
+            data = origin
+        elif task_type == constants.TASK_TYPE_IMPORT_INSTANCE:
+            provider_type = constants.PROVIDER_TYPE_IMPORT
+            data = destination
+        else:
+            raise exception.NotFound(
+                "Unknown task type: %s" % task_type)
+
+        provider = factory.get_provider(data["type"], provider_type)
         progress_update_manager = _ConductorProgressUpdateManager(ctxt,
         progress_update_manager = _ConductorProgressUpdateManager(ctxt,
                                                                   task_id)
                                                                   task_id)
         provider.set_progress_update_manager(progress_update_manager)
         provider.set_progress_update_manager(progress_update_manager)
-        result = provider.export_instance(ctxt, connection_info, instance,
-                                          export_path)
-        mp_q.put(result)
-    except Exception as ex:
-        mp_q.put(str(ex))
-        LOG.exception(ex)
 
 
+        connection_info = data.get("connection_info", {})
+        target_environment = data.get("target_environment", {})
 
 
-def _import_instance(provider, connection_info, target_environment, instance,
-                     export_info, ctxt, task_id, mp_q):
-    try:
-        # Setting up logging, needed since this is a new process
-        utils.setup_logging()
+        if provider_type == constants.PROVIDER_TYPE_EXPORT:
+            export_path = _get_task_export_path(task_id, create=True)
 
 
-        progress_update_manager = _ConductorProgressUpdateManager(ctxt,
-                                                                  task_id)
-        provider.set_progress_update_manager(progress_update_manager)
-        result = provider.import_instance(ctxt, connection_info,
-                                          target_environment, instance,
-                                          export_info)
+            result = provider.export_instance(ctxt, connection_info, instance,
+                                              export_path)
+            result[TMP_DIRS_KEY] = [export_path]
+        else:
+            result = provider.import_instance(ctxt, connection_info,
+                                              target_environment, instance,
+                                              task_info)
         mp_q.put(result)
         mp_q.put(result)
     except Exception as ex:
     except Exception as ex:
         mp_q.put(str(ex))
         mp_q.put(str(ex))