소스 검색

Add 'asynchronous' kwarg to WorkerServerEndpoint.exec_task.

Nashwan Azhari 5 년 전
부모
커밋
504e25c14d
2개의 변경된 파일54개의 추가작업 그리고 32개의 파일을 삭제
  1. 9 1
      coriolis/worker/rpc/client.py
  2. 45 31
      coriolis/worker/rpc/server.py

+ 9 - 1
coriolis/worker/rpc/client.py

@@ -39,7 +39,15 @@ class WorkerClient(rpc.BaseRPCClient):
         self._cast(
             ctxt, 'exec_task', task_id=task_id, task_type=task_type,
             origin=origin, destination=destination, instance=instance,
-            task_info=task_info)
+            task_info=task_info, asynchronous=True)
+
+    def run_task(self, ctxt, server, task_id, task_type, origin, destination,
+                 instance, task_info):
+        cctxt = self._client.prepare(server=server)
+        cctxt.cast(
+            ctxt, 'exec_task', task_id=task_id, task_type=task_type,
+            origin=origin, destination=destination, instance=instance,
+            task_info=task_info, asynchronous=False)
 
     def cancel_task(self, ctxt, task_id, process_id, force):
         return self._call(

+ 45 - 31
coriolis/worker/rpc/server.py

@@ -224,7 +224,7 @@ class WorkerServerEndpoint(object):
         return result
 
     def _exec_task_process(self, ctxt, task_id, task_type, origin, destination,
-                           instance, task_info):
+                           instance, task_info, report_to_conductor=True):
         mp_ctx = multiprocessing.get_context('spawn')
         mp_q = mp_ctx.Queue()
         mp_log_q = mp_ctx.Queue()
@@ -237,23 +237,25 @@ class WorkerServerEndpoint(object):
             ctxt, task_id, task_type, origin, destination)
 
         try:
-            LOG.debug(
-                "Attempting to set task host on Conductor for task '%s'.",
-                task_id)
-            self._rpc_conductor_client.set_task_host(
-                ctxt, task_id, self._server)
+            if report_to_conductor:
+                LOG.debug(
+                    "Attempting to set task host on Conductor for task '%s'.",
+                    task_id)
+                self._rpc_conductor_client.set_task_host(
+                    ctxt, task_id, self._server)
             LOG.debug(
                 "Attempting to start process for task with ID '%s'", task_id)
             self._start_process_with_custom_library_paths(
                 p, extra_library_paths)
             LOG.info("Task process started: %s", task_id)
+            if report_to_conductor:
+                LOG.debug(
+                    "Attempting to set task process on Conductor for task '%s'.",
+                    task_id)
+                self._rpc_conductor_client.set_task_process(
+                    ctxt, task_id, p.pid)
             LOG.debug(
-                "Attempting to set task process on Conductor for task '%s'.",
-                task_id)
-            self._rpc_conductor_client.set_task_process(
-                ctxt, task_id, p.pid)
-            LOG.debug(
-                "Successfully started and retported task process for task "
+                "Successfully started and reported task process for task "
                 "with ID '%s'.", task_id)
         except (Exception, KeyboardInterrupt) as ex:
             LOG.debug(
@@ -290,39 +292,51 @@ class WorkerServerEndpoint(object):
         return result
 
     def exec_task(self, ctxt, task_id, task_type, origin, destination,
-                  instance, task_info):
+                  instance, task_info, asynchronous=True):
         try:
             task_result = self._exec_task_process(
                 ctxt, task_id, task_type, origin, destination,
-                instance, task_info)
+                instance, task_info, report_to_conductor=asynchronous)
 
             LOG.info(
                 "Output of completed %s task with ID %s: %s",
                 task_type, task_id,
                 utils.sanitize_task_info(task_result))
 
+            if not asynchronous:
+                return task_result
             self._rpc_conductor_client.task_completed(
                 ctxt, task_id, task_result)
         except exception.TaskProcessCanceledException as ex:
-            LOG.debug(
-                "Task with ID '%s' appears to have been cancelled. "
-                "Confirming cancellation to Conductor now. Error was: %s",
-                task_id, utils.get_exception_details())
-            LOG.exception(ex)
-            self._rpc_conductor_client.confirm_task_cancellation(
-                ctxt, task_id, str(ex))
+            if asynchronous:
+                LOG.debug(
+                    "Task with ID '%s' appears to have been cancelled. "
+                    "Confirming cancellation to Conductor now. Error was: %s",
+                    task_id, utils.get_exception_details())
+                LOG.exception(ex)
+                self._rpc_conductor_client.confirm_task_cancellation(
+                    ctxt, task_id, str(ex))
+            else:
+                raise
         except exception.NoSuitableWorkerServiceError as ex:
-            LOG.warn(
-                "A conductor-side scheduling error has occurred following the "
-                "completion of task '%s'. Ignoring. Error was: %s",
-                task_id, utils.get_exception_details())
+            if asynchronous:
+                LOG.warn(
+                    "A conductor-side scheduling error has occurred following "
+                    "the completion of task '%s'. Ignoring. Error was: %s",
+                    task_id, utils.get_exception_details())
+            else:
+                raise
         except Exception as ex:
-            LOG.debug(
-                "Task with ID '%s' has error'd out. Reporting error to "
-                "Conductor now. Error was: %s",
-                task_id, utils.get_exception_details())
-            LOG.exception(ex)
-            self._rpc_conductor_client.set_task_error(ctxt, task_id, str(ex))
+            if asynchronous:
+                LOG.debug(
+                    "Task with ID '%s' has error'd out. Reporting error to "
+                    "Conductor now. Error was: %s",
+                    task_id, utils.get_exception_details())
+                LOG.exception(ex)
+                self._rpc_conductor_client.set_task_error(
+                        ctxt, task_id, str(ex))
+            else:
+                raise
 
     def get_endpoint_instances(self, ctxt, platform_name, connection_info,
                                source_environment, marker, limit,