Selaa lähdekoodia

Add self-registration for workers.

Nashwan Azhari 5 vuotta sitten
vanhempi
sitoutus
33c53142a2

+ 1 - 1
coriolis/api/v1/views/endpoint_view.py

@@ -14,7 +14,7 @@ def _format_endpoint(req, endpoint, keys=None):
         transform(k, v) for k, v in endpoint.items()))
         transform(k, v) for k, v in endpoint.items()))
     mapped_regions = endpoint_dict.get('mapped_regions', [])
     mapped_regions = endpoint_dict.get('mapped_regions', [])
     endpoint_dict['mapped_regions'] = [
     endpoint_dict['mapped_regions'] = [
-        reg['region_id'] for reg in mapped_regions]
+        reg['id'] for reg in mapped_regions]
 
 
     return endpoint_dict
     return endpoint_dict
 
 

+ 5 - 1
coriolis/api/v1/views/region_view.py

@@ -15,7 +15,11 @@ def _format_region(req, region, keys=None):
 
 
     mapped_endpoints = region_dict.get('mapped_endpoints', [])
     mapped_endpoints = region_dict.get('mapped_endpoints', [])
     region_dict['mapped_endpoints'] = [
     region_dict['mapped_endpoints'] = [
-        endp['endpoint_id'] for endp in mapped_endpoints]
+        endp['id'] for endp in mapped_endpoints]
+
+    mapped_services = region_dict.get('mapped_services', [])
+    region_dict['mapped_services'] = [
+        svc['id'] for svc in mapped_services]
 
 
     return region_dict
     return region_dict
 
 

+ 1 - 1
coriolis/api/v1/views/service_view.py

@@ -15,7 +15,7 @@ def _format_service(req, service, keys=None):
 
 
     mapped_regions = service_dict.get('mapped_regions', [])
     mapped_regions = service_dict.get('mapped_regions', [])
     service_dict['mapped_regions'] = [
     service_dict['mapped_regions'] = [
-        mapping['region_id'] for mapping in mapped_regions]
+        mapping['id'] for mapping in mapped_regions]
 
 
     return service_dict
     return service_dict
 
 

+ 13 - 2
coriolis/conductor/rpc/client.py

@@ -338,10 +338,21 @@ class ConductorClient(object):
             ctxt, 'delete_region', region_id=region_id)
             ctxt, 'delete_region', region_id=region_id)
 
 
     def register_service(
     def register_service(
-            self, ctxt, host, binary, topic, enabled, mapped_regions):
+            self, ctxt, host, binary, topic, enabled, mapped_regions,
+            providers=None, specs=None):
         return self._client.call(
         return self._client.call(
             ctxt, 'register_service', host=host, binary=binary,
             ctxt, 'register_service', host=host, binary=binary,
-            topic=topic, enabled=enabled, mapped_regions=mapped_regions)
+            topic=topic, enabled=enabled, mapped_regions=mapped_regions,
+            providers=providers, specs=specs)
+
+    def check_service_registered(self, ctxt, host, binary, topic):
+        return self._client.call(
+            ctxt, 'check_service_registered', host=host, binary=binary,
+            topic=topic)
+
+    def refresh_service_status(self, ctxt, service_id):
+        return self._client.call(
+            ctxt, 'refresh_service_status', service_id=service_id)
 
 
     def get_services(self, ctxt):
     def get_services(self, ctxt):
         return self._client.call(ctxt, 'get_services')
         return self._client.call(ctxt, 'get_services')

+ 171 - 73
coriolis/conductor/rpc/server.py

@@ -4,6 +4,7 @@
 import copy
 import copy
 import functools
 import functools
 import random
 import random
+import time
 import uuid
 import uuid
 
 
 from oslo_concurrency import lockutils
 from oslo_concurrency import lockutils
@@ -169,14 +170,29 @@ def service_synchronized(func):
 class ConductorServerEndpoint(object):
 class ConductorServerEndpoint(object):
     def __init__(self):
     def __init__(self):
         self._licensing_client = licensing_client.LicensingClient.from_env()
         self._licensing_client = licensing_client.LicensingClient.from_env()
-        self._rpc_worker_client = rpc_worker_client.WorkerClient()
-        self._replica_cron_client = rpc_cron_client.ReplicaCronClient()
-        self._scheduler_client = rpc_scheduler_client.SchedulerClient()
+
+    # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
+    # oslo_messaging clients as the underlying eventlet thread queues will
+    # be invalidated. Considering this class both serves from a "main
+    # process" as well as forking child processes, it is safest to
+    # re-instantiate the clients every time:
+    @property
+    def _rpc_worker_client(self):
+        return rpc_worker_client.WorkerClient()
+
+    @property
+    def _scheduler_client(self):
+        return rpc_scheduler_client.SchedulerClient()
+
+    @property
+    def _replica_cron_client(self):
+        return rpc_cron_client.ReplicaCronClient()
 
 
     def get_all_diagnostics(self, ctxt):
     def get_all_diagnostics(self, ctxt):
         diagnostics = [
         diagnostics = [
             self.get_diagnostics(ctxt),
             self.get_diagnostics(ctxt),
-            self._replica_cron_client.get_diagnostics(ctxt)]
+            self._replica_cron_client.get_diagnostics(ctxt),
+            self._scheduler_client.get_diagnostics(ctxt)]
         worker_diagnostics = []
         worker_diagnostics = []
         for worker_service in self._scheduler_client.get_workers_for_specs(
         for worker_service in self._scheduler_client.get_workers_for_specs(
                 ctxt):
                 ctxt):
@@ -214,23 +230,43 @@ class ConductorServerEndpoint(object):
             return service
             return service
         return db_api.get_service(ctxt, service['id'])
         return db_api.get_service(ctxt, service['id'])
 
 
+    def _get_worker_rpc_for_host(self, host, *client_args, **client_kwargs):
+        rpc_client_class = RPC_TOPIC_TO_CLIENT_CLASS_MAP[
+            constants.WORKER_MAIN_MESSAGING_TOPIC]
+        topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
+            "main_topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
+            "host": host})
+        return rpc_client_class(topic=topic, *client_args, **client_kwargs)
+
     def _get_worker_service_rpc_for_specs(
     def _get_worker_service_rpc_for_specs(
-            self, ctxt, provider_requirements=None, region_ids=None,
+            self, ctxt, provider_requirements=None, region_sets=None,
             enabled=True, random_choice=False, raise_on_no_matches=True):
             enabled=True, random_choice=False, raise_on_no_matches=True):
+        requirements_str = (
+            "enabled=%s; region_sets=%s; provider_requirements=%s" % (
+                enabled, region_sets, provider_requirements))
+        LOG.info(
+            "Requesting Worker Service from scheduler with the following "
+            "specifications: %s", requirements_str)
         services = self._scheduler_client.get_workers_for_specs(
         services = self._scheduler_client.get_workers_for_specs(
             ctxt, provider_requirements=provider_requirements,
             ctxt, provider_requirements=provider_requirements,
-            region_ids=region_ids, enabled=enabled)
-        services = self._scheduler_client.get_workers_for_specs(ctxt)
+            region_sets=region_sets, enabled=enabled)
         if not services:
         if not services:
             if raise_on_no_matches:
             if raise_on_no_matches:
                 raise exception.NoSuitableWorkerServiceError()
                 raise exception.NoSuitableWorkerServiceError()
             return None
             return None
+        LOG.debug(
+            "Was offered Worker Services with the following IDs for "
+            "requirements '%s': %s",
+            requirements_str, [s["id"] for s in services])
 
 
         selected_service = services[0]
         selected_service = services[0]
         if random_choice:
         if random_choice:
             selected_service = random.choice(services)
             selected_service = random.choice(services)
         service = db_api.get_service(ctxt, selected_service["id"])
         service = db_api.get_service(ctxt, selected_service["id"])
 
 
+        LOG.info(
+            "Was offered Worker Service with ID '%s' for requirements: %s",
+            service.id, requirements_str)
         return self._get_rpc_client_for_service(service)
         return self._get_rpc_client_for_service(service)
 
 
     def _check_delete_reservation_for_transfer(self, transfer_action):
     def _check_delete_reservation_for_transfer(self, transfer_action):
@@ -351,7 +387,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
         return worker_rpc.get_endpoint_instances(
         return worker_rpc.get_endpoint_instances(
@@ -364,7 +400,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
 
 
@@ -378,7 +414,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [
                 endpoint.type: [
                     constants.PROVIDER_TYPE_SOURCE_ENDPOINT_OPTIONS]})
                     constants.PROVIDER_TYPE_SOURCE_ENDPOINT_OPTIONS]})
@@ -392,7 +428,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [
                 endpoint.type: [
                     constants.PROVIDER_TYPE_DESTINATION_ENDPOINT_OPTIONS]})
                     constants.PROVIDER_TYPE_DESTINATION_ENDPOINT_OPTIONS]})
@@ -404,7 +440,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_NETWORKS]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_NETWORKS]})
 
 
@@ -416,7 +452,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_STORAGE]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_STORAGE]})
 
 
@@ -428,7 +464,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
 
 
@@ -440,7 +476,7 @@ class ConductorServerEndpoint(object):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
 
 
@@ -453,7 +489,7 @@ class ConductorServerEndpoint(object):
 
 
         worker_rpc = self._get_worker_service_rpc_for_specs(
         worker_rpc = self._get_worker_service_rpc_for_specs(
             ctxt, enabled=True,
             ctxt, enabled=True,
-            region_ids=[m['region_id'] for m in endpoint.mapped_regions],
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
             provider_requirements={
             provider_requirements={
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
                 endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
 
 
@@ -529,35 +565,31 @@ class ConductorServerEndpoint(object):
         }
         }
 
 
     def _get_worker_service_rpc_for_task(
     def _get_worker_service_rpc_for_task(
-            self, ctxt, task, origin_endpoint, destination_endpoint):
+            self, ctxt, task, origin_endpoint, destination_endpoint,
+            retry_count=5, retry_period=2):
         LOG.debug(
         LOG.debug(
-            "Requesting Worker Service for task with ID '%s' (type) '%s' "
-            "from endpoints '%s' to '%s'", task.id, task.task_type,
-            origin_endpoint.id, destination_endpoint.id)
+            "Compiling required Worker Service specs for task with "
+            "ID '%s' (type '%s') from endpoints '%s' to '%s'",
+            task.id, task.task_type, origin_endpoint.id,
+            destination_endpoint.id)
         task_cls = tasks_factory.get_task_runner_class(task.task_type)
         task_cls = tasks_factory.get_task_runner_class(task.task_type)
 
 
         # determine required Coriolis regions based on the endpoints:
         # determine required Coriolis regions based on the endpoints:
-        required_regions = []
-        required_platform = task_cls.get_required_platform()
+        required_region_sets = []
         origin_endpoint_region_ids = [
         origin_endpoint_region_ids = [
-            m.region_id for m in origin_endpoint.mapped_regions]
+            r.id for r in origin_endpoint.mapped_regions]
         destination_endpoint_region_ids = [
         destination_endpoint_region_ids = [
-            m.region_id for m in origin_endpoint.mapped_regions]
-
-        if required_platform == constants.TASK_PLATFORM_SOURCE:
-            required_regions = origin_endpoint_region_ids
-        if required_platform == constants.TASK_PLATFORM_DESTINATION:
-            required_regions = destination_endpoint_region_ids
-        if required_platform == constants.TASK_PLATFORM_BILATERAL:
-            # NOTE: backwards-compatibility for endpoints with
-            # no associated regions:
-            if not origin_endpoint_region_ids and (
-                    not destination_endpoint_region_ids):
-                required_regions = []
-            else:
-                required_regions = list(
-                    set(origin_endpoint_region_ids).intersection(
-                        set(destination_endpoint_region_ids)))
+            r.id for r in destination_endpoint.mapped_regions]
+
+        required_platform = task_cls.get_required_platform()
+        if required_platform in (
+                constants.TASK_PLATFORM_SOURCE,
+                constants.TASK_PLATFORM_BILATERAL):
+            required_region_sets.append(origin_endpoint_region_ids)
+        if required_platform in (
+                constants.TASK_PLATFORM_DESTINATION,
+                constants.TASK_PLATFORM_BILATERAL):
+            required_region_sets.append(destination_endpoint_region_ids)
 
 
         # determine provider requirements:
         # determine provider requirements:
         provider_requirements = {}
         provider_requirements = {}
@@ -572,28 +604,50 @@ class ConductorServerEndpoint(object):
                     constants.PROVIDER_PLATFORM_DESTINATION])
                     constants.PROVIDER_PLATFORM_DESTINATION])
 
 
         worker_rpc = None
         worker_rpc = None
-        try:
-            worker_rpc = self._get_worker_service_rpc_for_specs(
-                ctxt, provider_requirements=provider_requirements,
-                region_ids=required_regions, enabled=True)
-        except Exception as ex:
-            LOG.warn(
-                "Failed to schedule task with ID '%s'. Marking as such.")
-            message = (
-                "Failed to schedule task. This may indicate that there are no "
-                "Coriolis Worker services able to perform the task on the "
-                "platforms and in the Coriolis Regions required by the selected"
-                " source/destination Coriolis Endpoints. Please review the "
-                "scheduler logs for more exact details. "
-                "Error message was: %s" % str(ex))
-            db_api.set_task_status(
-                ctxt, task.id, constants.TASK_STATUS_FAILED_TO_SCHEDULE,
-                exception_details=message)
-            raise
-
-        return worker_rpc
+        exceptions = []
+        for i in range(retry_count):
+            try:
+                LOG.debug(
+                    "Requesting Worker Service for task with ID '%s' (type "
+                    "'%s') from endpoints '%s' to '%s'", task.id,
+                    task.task_type, origin_endpoint.id,
+                    destination_endpoint.id)
+                worker_rpc = self._get_worker_service_rpc_for_specs(
+                    ctxt, provider_requirements=provider_requirements,
+                    region_sets=required_region_sets, enabled=True)
+                LOG.debug(
+                    "Scheduler has granted Worker Service for task with ID "
+                    "'%s' (type '%s') from endpoints '%s' to '%s'",
+                    task.id, task.task_type, origin_endpoint.id,
+                    destination_endpoint.id)
+                return worker_rpc
+            except Exception as ex:
+                LOG.warn(
+                    "Failed to schedule task with ID '%s' (attempt %d/%d). "
+                    "waiting %d seconds and then retrying. Error was: %s",
+                    task.id, i+1, retry_count, utils.get_exception_details())
+                exceptions.append(ex)
+                time.sleep(retry_period)
+
+        errors_str = ""
+        nerrors = len(exceptions)
+        for i, ex in enumerate(exceptions):
+            errors_str = "%s; (%d/%d) %s" % (
+                errors_str, i+1, nerrors, str(ex))
+        message = (
+            "Failed to schedule task. This may indicate that there are no "
+            "Coriolis Worker services able to perform the task on the "
+            "platforms and in the Coriolis Regions required by the "
+            "selected source/destination Coriolis Endpoints. Please review"
+            " the scheduler logs for more exact details. "
+            "Encountered errors were: %s" % errors_str[2:])
+        db_api.set_task_status(
+            ctxt, task.id, constants.TASK_STATUS_FAILED_TO_SCHEDULE,
+            exception_details=message)
 
 
-    def _begin_tasks(self, ctxt, execution, task_info={}):
+    def _begin_tasks(
+            self, ctxt, execution, task_info={},
+            scheduling_retry_count=5, scheduling_retry_period=2):
         """ Starts all non-error-only tasks which have no depencies. """
         """ Starts all non-error-only tasks which have no depencies. """
         if not ctxt.trust_id:
         if not ctxt.trust_id:
             keystone.create_trust(ctxt)
             keystone.create_trust(ctxt)
@@ -617,7 +671,9 @@ class ConductorServerEndpoint(object):
                     ctxt, task.id, constants.TASK_STATUS_PENDING)
                     ctxt, task.id, constants.TASK_STATUS_PENDING)
                 try:
                 try:
                     worker_rpc = self._get_worker_service_rpc_for_task(
                     worker_rpc = self._get_worker_service_rpc_for_task(
-                        ctxt, task, origin_endpoint, destination_endpoint)
+                        ctxt, task, origin_endpoint, destination_endpoint,
+                        retry_count=scheduling_retry_count,
+                        retry_period=scheduling_retry_period)
                     worker_rpc.begin_task(
                     worker_rpc.begin_task(
                         ctxt, server=None,
                         ctxt, server=None,
                         task_id=task.id,
                         task_id=task.id,
@@ -995,7 +1051,7 @@ class ConductorServerEndpoint(object):
 
 
         self._check_execution_tasks_sanity(execution, replica.info)
         self._check_execution_tasks_sanity(execution, replica.info)
 
 
-        # update the action info for all of the Replicas' instnaces:
+        # update the action info for all of the Replicas' instances:
         for instance in replica.instances:
         for instance in replica.instances:
             db_api.update_transfer_action_info_for_instance(
             db_api.update_transfer_action_info_for_instance(
                 ctxt, replica.id, instance, replica.info[instance])
                 ctxt, replica.id, instance, replica.info[instance])
@@ -1577,9 +1633,7 @@ class ConductorServerEndpoint(object):
                         task.status, task.id, execution.id)
                         task.status, task.id, execution.id)
                     db_api.set_task_status(
                     db_api.set_task_status(
                         ctxt, task.id, constants.TASK_STATUS_CANCELLING)
                         ctxt, task.id, constants.TASK_STATUS_CANCELLING)
-                    worker_rpc = self._get_rpc_client_for_service(
-                        self._get_any_worker_service(ctxt))
-                    # TODO(aznashwan): cancel on right worker:
+                    worker_rpc = self._get_worker_rpc_for_host(task.host)
                     worker_rpc.cancel_task(
                     worker_rpc.cancel_task(
                         ctxt, task.host, task.id, task.process_id, force)
                         ctxt, task.host, task.id, task.process_id, force)
                 # let any on-error tasks run to completion but mark
                 # let any on-error tasks run to completion but mark
@@ -1661,10 +1715,16 @@ class ConductorServerEndpoint(object):
                 "Task with ID '%s' is in '%s' status instead of the "
                 "Task with ID '%s' is in '%s' status instead of the "
                 "expected '%s' required for it to have a task host set." % (
                 "expected '%s' required for it to have a task host set." % (
                     task_id, task.status, constants.TASK_STATUS_PENDING))
                     task_id, task.status, constants.TASK_STATUS_PENDING))
+        LOG.info(
+            "Setting host/process for task with ID '%s' to '%s/%s'",
+            task_id, host, process_id)
         db_api.set_task_host(ctxt, task_id, host, process_id)
         db_api.set_task_host(ctxt, task_id, host, process_id)
         db_api.set_task_status(
         db_api.set_task_status(
             ctxt, task_id, new_status,
             ctxt, task_id, new_status,
             exception_details=exception_details)
             exception_details=exception_details)
+        LOG.info(
+            "Successfully set host/process for task with ID '%s' to '%s/%s'",
+            task_id, host, process_id)
 
 
     def _check_clean_execution_deadlock(
     def _check_clean_execution_deadlock(
             self, ctxt, execution, task_statuses=None, requery=True):
             self, ctxt, execution, task_statuses=None, requery=True):
@@ -1861,6 +1921,10 @@ class ConductorServerEndpoint(object):
                     destination=destination,
                     destination=destination,
                     instance=task.instance,
                     instance=task.instance,
                     task_info=task_info)
                     task_info=task_info)
+                LOG.debug(
+                    "Successfully started task with ID '%s' (type '%s') "
+                    "for execution '%s'", task.id, task.task_type,
+                    execution.id)
             except Exception as ex:
             except Exception as ex:
                 msg = (
                 msg = (
                     "Error occured while starting new task '%s'. "
                     "Error occured while starting new task '%s'. "
@@ -2683,13 +2747,14 @@ class ConductorServerEndpoint(object):
         db_api.delete_region(ctxt, region_id)
         db_api.delete_region(ctxt, region_id)
 
 
     def register_service(
     def register_service(
-            self, ctxt, host, binary, topic, enabled, mapped_regions=None):
-        exists = db_api.find_service(ctxt, host, binary, topic=topic)
-        if exists:
+            self, ctxt, host, binary, topic, enabled, mapped_regions=None,
+            providers=None, specs=None):
+        service = db_api.find_service(ctxt, host, binary, topic=topic)
+        if service:
             raise exception.Conflict(
             raise exception.Conflict(
                 "A Service with the specified parameters (host %s, binary %s, "
                 "A Service with the specified parameters (host %s, binary %s, "
                 "topic %s) has already been registered under ID: %s" % (
                 "topic %s) has already been registered under ID: %s" % (
-                    host, binary, topic, exists.id))
+                    host, binary, topic, service.id))
 
 
         service = models.Service()
         service = models.Service()
         service.id = str(uuid.uuid4())
         service.id = str(uuid.uuid4())
@@ -2697,12 +2762,17 @@ class ConductorServerEndpoint(object):
         service.binary = binary
         service.binary = binary
         service.enabled = enabled
         service.enabled = enabled
         service.topic = topic
         service.topic = topic
+        service.status = constants.SERVICE_STATUS_UP
 
 
-        worker_rpc = self._get_rpc_client_for_service(service)
-        status = worker_rpc.get_service_status(ctxt)
+        if None in (providers, specs):
+            worker_rpc = self._get_rpc_client_for_service(service)
+            status = worker_rpc.get_service_status(ctxt)
 
 
-        service.providers = status["providers"]
-        service.specs = status["specs"]
+            service.providers = status["providers"]
+            service.specs = status["specs"]
+        else:
+            service.providers = providers
+            service.specs = specs
 
 
         # create the service:
         # create the service:
         db_api.add_service(ctxt, service)
         db_api.add_service(ctxt, service)
@@ -2726,6 +2796,34 @@ class ConductorServerEndpoint(object):
 
 
         return self.get_service(ctxt, service.id)
         return self.get_service(ctxt, service.id)
 
 
+    def check_service_registered(self, ctxt, host, binary, topic):
+        props = "host='%s', binary='%s', topic='%s'" % (host, binary, topic)
+        LOG.debug(
+            "Checking for existence of service with properties: %s", props)
+        service = db_api.find_service(ctxt, host, binary, topic=topic)
+        if service:
+            LOG.debug(
+                "Found service '%s' for properties %s", service.id, props)
+        else:
+            LOG.debug(
+                "Could not find any service with the specified "
+                "properties: %s", props)
+        return service
+
+    @service_synchronized
+    def refresh_service_status(self, ctxt, service_id):
+        LOG.debug("Updating registration for worker service '%s'", service_id)
+        service = db_api.get_service(ctxt, service_id)
+        worker_rpc = self._get_rpc_client_for_service(service)
+        status = worker_rpc.get_service_status(ctxt)
+        updated_values = {
+            "providers": status["providers"],
+            "specs": status["specs"],
+            "status": constants.SERVICE_STATUS_UP}
+        db_api.update_service(ctxt, service_id, updated_values)
+        LOG.debug("Successfully refreshed status of service '%s'", service_id)
+        return db_api.get_service(ctxt, service_id)
+
     def get_services(self, ctxt):
     def get_services(self, ctxt):
         return db_api.get_services(ctxt)
         return db_api.get_services(ctxt)
 
 

+ 54 - 0
coriolis/conductor/rpc/utils.py

@@ -0,0 +1,54 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import time
+
+from oslo_log import log as logging
+
+from coriolis import utils
+
+
+LOG = logging.getLogger(__name__)
+
+
+def check_create_registration_for_service(
+        conductor_rpc, request_context, host, binary, topic, enabled=False,
+        mapped_regions=None, providers=None, specs=None, retry_period=30):
+    """ Checks with the conductor whether or not a service has already been
+    registered for this host and topic and creates one if not.
+    If the service is already registered, directs the conductor to refresh the
+    service status.
+    """
+    props = "host='%s', binary='%s', topic='%s'" % (host, binary, topic)
+    while True:
+        try:
+            # check is service already exists:
+            LOG.info(
+                "Checking with conductor if service with following porperties "
+                "was already registered: %s", props)
+            worker_service = conductor_rpc.check_service_registered(
+                request_context, host, binary, topic)
+            if worker_service:
+                LOG.info(
+                    "A service with properties %s has already been registered "
+                    "under ID '%s'. Updating existing registration.",
+                    props, worker_service['id'])
+                worker_service = conductor_rpc.update_service(
+                    request_context, worker_service['id'], updated_values={
+                        "providers": providers,
+                        "specs": specs})
+            else:
+                LOG.debug(
+                    "Attempting to register new service with properties: %s",
+                    props)
+                worker_service = conductor_rpc.register_service(
+                    request_context, host, binary, topic, enabled,
+                    mapped_regions=mapped_regions, providers=providers,
+                    specs=specs)
+            return worker_service
+        except Exception as ex:
+            LOG.warn(
+                "Failed to register service with specs %s. Retrying again in "
+                "%d seconds. Error was: %s", props, retry_period,
+                utils.get_exception_details())
+            time.sleep(retry_period)

+ 20 - 16
coriolis/db/api.py

@@ -124,7 +124,9 @@ def _soft_delete_aware_query(context, *args, **kwargs):
     :param show_deleted: if True, overrides context's show_deleted field.
     :param show_deleted: if True, overrides context's show_deleted field.
     """
     """
     query = _model_query(context, *args)
     query = _model_query(context, *args)
-    show_deleted = kwargs.get('show_deleted') or context.show_deleted
+    show_deleted = kwargs.get('show_deleted')
+    if context and context.show_deleted:
+        show_deleted = True
 
 
     if not show_deleted:
     if not show_deleted:
         query = query.filter_by(deleted_at=None)
         query = query.filter_by(deleted_at=None)
@@ -156,7 +158,7 @@ def get_endpoint(context, endpoint_id):
 def add_endpoint(context, endpoint):
 def add_endpoint(context, endpoint):
     endpoint.user_id = context.user
     endpoint.user_id = context.user
     endpoint.project_id = context.tenant
     endpoint.project_id = context.tenant
-    context.session.add(endpoint)
+    _session(context).add(endpoint)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -300,7 +302,7 @@ def add_replica_tasks_execution(context, execution):
             action_id=execution.action.id).first()[0] or 0
             action_id=execution.action.id).first()[0] or 0
     execution.number = max_number + 1
     execution.number = max_number + 1
 
 
-    context.session.add(execution)
+    _session(context).add(execution)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -387,7 +389,7 @@ def add_replica_schedule(context, schedule, post_create_callable=None):
 
 
     if schedule.replica.project_id != context.tenant:
     if schedule.replica.project_id != context.tenant:
         raise exception.NotAuthorized()
         raise exception.NotAuthorized()
-    context.session.add(schedule)
+    _session(context).add(schedule)
     if post_create_callable:
     if post_create_callable:
         post_create_callable(context, schedule)
         post_create_callable(context, schedule)
 
 
@@ -444,7 +446,7 @@ def get_endpoint_replicas_count(context, endpoint_id):
 def add_replica(context, replica):
 def add_replica(context, replica):
     replica.user_id = context.user
     replica.user_id = context.user
     replica.project_id = context.tenant
     replica.project_id = context.tenant
-    context.session.add(replica)
+    _session(context).add(replica)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -533,7 +535,7 @@ def get_migration(context, migration_id):
 def add_migration(context, migration):
 def add_migration(context, migration):
     migration.user_id = context.user
     migration.user_id = context.user
     migration.project_id = context.tenant
     migration.project_id = context.tenant
-    context.session.add(migration)
+    _session(context).add(migration)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -682,7 +684,7 @@ def add_task_event(context, task_id, level, message):
     task_event.task_id = task_id
     task_event.task_id = task_id
     task_event.level = level
     task_event.level = level
     task_event.message = message
     task_event.message = message
-    context.session.add(task_event)
+    _session(context).add(task_event)
 
 
 
 
 def _get_progress_update(context, task_id, current_step):
 def _get_progress_update(context, task_id, current_step):
@@ -698,7 +700,7 @@ def add_task_progress_update(context, task_id, current_step, total_steps,
     task_progress_update = _get_progress_update(context, task_id, current_step)
     task_progress_update = _get_progress_update(context, task_id, current_step)
     if not task_progress_update:
     if not task_progress_update:
         task_progress_update = models.TaskProgressUpdate()
         task_progress_update = models.TaskProgressUpdate()
-        context.session.add(task_progress_update)
+        _session(context).add(task_progress_update)
 
 
     task_progress_update.task_id = task_id
     task_progress_update.task_id = task_id
     task_progress_update.current_step = current_step
     task_progress_update.current_step = current_step
@@ -744,20 +746,22 @@ def update_replica(context, replica_id, updated_values):
 
 
 @enginefacade.writer
 @enginefacade.writer
 def add_region(context, region):
 def add_region(context, region):
-    context.session.add(region)
+    _session(context).add(region)
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
 def get_regions(context):
 def get_regions(context):
-    q = _soft_delete_aware_query(context, models.Region).options(
-        orm.joinedload('mapped_endpoints'))
+    q = _soft_delete_aware_query(context, models.Region)
+    q = q.options(orm.joinedload('mapped_endpoints'))
+    q = q.options(orm.joinedload('mapped_services'))
     return q.all()
     return q.all()
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
 def get_region(context, region_id):
 def get_region(context, region_id):
-    q = _soft_delete_aware_query(context, models.Region).options(
-        orm.joinedload('mapped_endpoints'))
+    q = _soft_delete_aware_query(context, models.Region)
+    q = q.options(orm.joinedload('mapped_endpoints'))
+    q = q.options(orm.joinedload('mapped_services'))
     return q.filter(
     return q.filter(
         models.Region.id == region_id).first()
         models.Region.id == region_id).first()
 
 
@@ -796,7 +800,7 @@ def add_endpoint_region_mapping(context, endpoint_region_mapping):
             "('%s') and the endpoint ID ('%s') must both be non-null." % (
             "('%s') and the endpoint ID ('%s') must both be non-null." % (
                 region_id, endpoint_id))
                 region_id, endpoint_id))
 
 
-    context.session.add(endpoint_region_mapping)
+    _session(context).add(endpoint_region_mapping)
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
@@ -848,7 +852,7 @@ def get_mapped_endpoints_for_region(context, region_id):
 
 
 @enginefacade.writer
 @enginefacade.writer
 def add_service(context, service):
 def add_service(context, service):
-    context.session.add(service)
+    _session(context).add(service)
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
@@ -992,7 +996,7 @@ def add_service_region_mapping(context, service_region_mapping):
             "('%s') and the service ID ('%s') must both be non-null." % (
             "('%s') and the service ID ('%s') must both be non-null." % (
                 region_id, service_id))
                 region_id, service_id))
 
 
-    context.session.add(service_region_mapping)
+    _session(context).add(service_region_mapping)
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader

+ 8 - 16
coriolis/db/sqlalchemy/models.py

@@ -327,10 +327,8 @@ class Service(BASE, models.TimestampMixin, models.ModelBase,
     providers = sqlalchemy.Column(types.Json(), nullable=True)
     providers = sqlalchemy.Column(types.Json(), nullable=True)
     specs = sqlalchemy.Column(types.Json(), nullable=True)
     specs = sqlalchemy.Column(types.Json(), nullable=True)
     mapped_regions = orm.relationship(
     mapped_regions = orm.relationship(
-        ServiceRegionMapping, backref=orm.backref('service'),
-        cascade="all,delete",
-        primaryjoin="and_(ServiceRegionMapping.service_id==Service.id, "
-                    "ServiceRegionMapping.deleted=='0')")
+        'Region', back_populates='mapped_services',
+        secondary="service_region_mapping")
 
 
 
 
 class EndpointRegionMapping(
 class EndpointRegionMapping(
@@ -378,16 +376,12 @@ class Region(
         nullable=False)
         nullable=False)
 
 
     mapped_endpoints = orm.relationship(
     mapped_endpoints = orm.relationship(
-        EndpointRegionMapping, backref=orm.backref('region'),
-        cascade="all,delete",
-        primaryjoin="and_(EndpointRegionMapping.region_id==Region.id, "
-                    "EndpointRegionMapping.deleted=='0')")
+        'Endpoint', back_populates='mapped_regions',
+        secondary="endpoint_region_mapping")
 
 
     mapped_services = orm.relationship(
     mapped_services = orm.relationship(
-        ServiceRegionMapping, backref=orm.backref('region'),
-        cascade="all,delete",
-        primaryjoin="and_(ServiceRegionMapping.region_id==Region.id, "
-                    "ServiceRegionMapping.deleted=='0')")
+        'Service', back_populates='mapped_regions',
+        secondary="service_region_mapping")
 
 
 
 
 class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
 class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
@@ -412,10 +406,8 @@ class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
         primaryjoin="and_(BaseTransferAction.destination_endpoint_id=="
         primaryjoin="and_(BaseTransferAction.destination_endpoint_id=="
                     "Endpoint.id, BaseTransferAction.deleted=='0')")
                     "Endpoint.id, BaseTransferAction.deleted=='0')")
     mapped_regions = orm.relationship(
     mapped_regions = orm.relationship(
-        EndpointRegionMapping, backref=orm.backref('endpoint'),
-        cascade="all,delete",
-        primaryjoin="and_(EndpointRegionMapping.endpoint_id==Endpoint.id, "
-                    "EndpointRegionMapping.deleted=='0')")
+        'Region', back_populates='mapped_endpoints',
+        secondary="endpoint_region_mapping")
 
 
 
 
 class ReplicaSchedule(BASE, models.TimestampMixin, models.ModelBase,
 class ReplicaSchedule(BASE, models.TimestampMixin, models.ModelBase,

+ 19 - 0
coriolis/exception.py

@@ -148,6 +148,7 @@ class Invalid(CoriolisException):
     code = 400
     code = 400
     safe = True
     safe = True
 
 
+
 class InvalidCustomOSDetectTools(Invalid):
 class InvalidCustomOSDetectTools(Invalid):
     message = _("The provided custom OS detect tools are invalid.")
     message = _("The provided custom OS detect tools are invalid.")
 
 
@@ -245,6 +246,10 @@ class NotFound(CoriolisException):
     safe = True
     safe = True
 
 
 
 
+class RegionNotFound(NotFound):
+    message = _("The specified Coriolis region(s) could not be found.")
+
+
 class OSMorphingToolsNotFound(NotFound):
 class OSMorphingToolsNotFound(NotFound):
     message = _(
     message = _(
         'No OSMorphing tools were found for OS type "%(os_type)s" for this VM.'
         'No OSMorphing tools were found for OS type "%(os_type)s" for this VM.'
@@ -390,6 +395,20 @@ class UnrecognizedWorkerInitSystem(CoriolisException):
         "Coriolis to be able to use it for data Replication.")
         "Coriolis to be able to use it for data Replication.")
 
 
 
 
+class NoRegionError(CoriolisException):
+    safe = True
+    code = 503
+    message = _(
+        "No Coriolis region is avaialable to process this request at this "
+        "time.")
+
+
+class NoSuitableRegionError(NoRegionError):
+    message = _(
+        "No Coriolis Region(s) fitting the criteria of the required operation "
+        "could be found.")
+
+
 class NoServiceError(CoriolisException):
 class NoServiceError(CoriolisException):
     safe = True
     safe = True
     code = 503
     code = 503

+ 26 - 11
coriolis/scheduler/filters/trivial_filters.py

@@ -12,25 +12,40 @@ LOG = logging.getLogger(__name__)
 
 
 class RegionsFilter(base.BaseServiceFilter):
 class RegionsFilter(base.BaseServiceFilter):
 
 
-    def __init__(self, regions):
+    def __init__(self, regions, any_region=False):
         self._regions = regions
         self._regions = regions
+        self._any_region = any_region
 
 
     def __repr__(self):
     def __repr__(self):
-        return "<%s(regions=%s)>" % (
-            self.__class__.__name__, self._regions)
+        return "<%s(regions=%s, any_region=%s)>" % (
+            self.__class__.__name__, self._regions, self._any_region)
 
 
     def rate_service(self, service):
     def rate_service(self, service):
-        service_regions = [
-            mapping["region_id"] for mapping in service.mapped_regions]
-        missing_regions = [
-            region
-            for region in self._regions
-            if region not in service_regions]
+        if not self._regions:
+            LOG.debug(
+                "No regions specified for this filter (%s). "
+                "Presuming service is valid.")
+            return 100
 
 
-        if missing_regions:
+        service_regions = [
+            region.id for region in service.mapped_regions]
+        found = []
+        missing = []
+        for region in self._regions:
+            if region in service_regions:
+                found.append(region)
+            else:
+                missing.append(region)
+
+        if not found:
+            LOG.debug(
+                "None of the requested regions are available on service (%s): "
+                "%s", service.id, self._regions)
+            return 0
+        if not self._any_region and missing:
             LOG.debug(
             LOG.debug(
                 "The following required regions are missing from service "
                 "The following required regions are missing from service "
-                "with ID '%s': %s", service.id, missing_regions)
+                "with ID '%s': %s", service.id, missing)
             return 0
             return 0
 
 
         return 100
         return 100

+ 2 - 21
coriolis/scheduler/rpc/client.py

@@ -30,26 +30,7 @@ class SchedulerClient(object):
 
 
     def get_workers_for_specs(
     def get_workers_for_specs(
             self, ctxt, provider_requirements=None,
             self, ctxt, provider_requirements=None,
-            region_ids=None, enabled=None):
+            region_sets=None, enabled=None):
         return self._client.call(
         return self._client.call(
-            ctxt, 'get_workers_for_specs', region_ids=region_ids,
+            ctxt, 'get_workers_for_specs', region_sets=region_sets,
             enabled=enabled, provider_requirements=provider_requirements)
             enabled=enabled, provider_requirements=provider_requirements)
-
-    '''
-    def get_workers_for_action(
-            self, ctxt, endpoint_type, provider_type, region_ids=None):
-        return self._client.call(
-            ctxt, 'get_workers_for_action', endpoint_type=endpoint_type,
-            provider_type=provider_type, region_ids=region_ids)
-
-    def get_workers_for_task(
-            self, ctxt, task_type, source_endpoint_type,
-            destination_endpoint_type, source_region_ids=None,
-            destination_region_ids=None):
-        return self._client.call(
-            ctxt, 'get_workers_for_task', task_type=task_type,
-            source_endpoint_type=source_endpoint_type,
-            destination_endpoint_type=destination_endpoint_type,
-            source_region_ids=source_region_ids,
-            destination_region_ids=destination_region_ids)
-    '''

+ 43 - 32
coriolis/scheduler/rpc/server.py

@@ -97,13 +97,42 @@ class SchedulerServerEndpoint(object):
         return sorted(
         return sorted(
             scores, key=lambda s: s[1], reverse=True)
             scores, key=lambda s: s[1], reverse=True)
 
 
+    def _filter_regions(
+            self, ctxt, region_ids, enabled=True, check_all_exist=True,
+            regions_cache=None):
+        found_regions = []
+        filtered_regions = []
+        regions = regions_cache
+        if not regions:
+            regions = db_api.get_regions(ctxt)
+        for region in regions:
+            if region.id in region_ids:
+                found_regions.append(region.id)
+                if region.enabled != enabled:
+                    continue
+                filtered_regions.append(region)
+
+        if check_all_exist:
+            missing_regions = set(region_ids).difference(
+                set(found_regions))
+            if missing_regions:
+                raise exception.RegionNotFound(
+                    "Failed to schedule job on regions %s as one or more "
+                    "of the proposed regions (%s) do not exist." % (
+                        region_ids, missing_regions))
+
+        return filtered_regions
+
     def get_workers_for_specs(
     def get_workers_for_specs(
             self, ctxt, provider_requirements=None,
             self, ctxt, provider_requirements=None,
-            region_ids=None, enabled=None):
+            region_sets=None, enabled=None, filter_disabled_regions=True):
         """ Returns a list of enabled Worker Services with the specified
         """ Returns a list of enabled Worker Services with the specified
         parameters.
         parameters.
         :param provider_requirements: dict of the form {
         :param provider_requirements: dict of the form {
             "<platform_type>": [constants.PROVIDER_TYPE_*, ...]}
             "<platform_type>": [constants.PROVIDER_TYPE_*, ...]}
+        param region_sets: list of lists of region IDs to filter for.
+        Services will be filtered unless they are associated with
+        at least one region in each region set.
         """
         """
         filters = []
         filters = []
         worker_services = self._get_all_worker_services(ctxt)
         worker_services = self._get_all_worker_services(ctxt)
@@ -111,12 +140,21 @@ class SchedulerServerEndpoint(object):
         LOG.debug(
         LOG.debug(
             "Searching for Worker Services with specs: %s" % {
             "Searching for Worker Services with specs: %s" % {
                 "provider_requirements": provider_requirements,
                 "provider_requirements": provider_requirements,
-                "region_ids": region_ids, "enabled": enabled})
+                "region_sets": region_sets, "enabled": enabled})
 
 
         if enabled is not None:
         if enabled is not None:
             filters.append(trivial_filters.EnabledFilter(enabled=enabled))
             filters.append(trivial_filters.EnabledFilter(enabled=enabled))
-        if region_ids:
-            filters.append(trivial_filters.RegionsFilter(region_ids))
+        if region_sets:
+            for region_set in region_sets:
+                filtered_regions = self._filter_regions(
+                    ctxt, region_set, enabled=filter_disabled_regions,
+                    check_all_exist=True)
+                if not filtered_regions:
+                    raise exception.NoSuitableRegionError(
+                        "None of the selected Regions (%s) are enabled or "
+                        "otherwise usable." % region_set)
+                filters.append(
+                    trivial_filters.RegionsFilter(region_set, any_region=True))
         if provider_requirements:
         if provider_requirements:
             filters.append(
             filters.append(
                 trivial_filters.ProviderTypesFilter(provider_requirements))
                 trivial_filters.ProviderTypesFilter(provider_requirements))
@@ -127,33 +165,6 @@ class SchedulerServerEndpoint(object):
             "Found Worker Services %s for specs: %s" % (
             "Found Worker Services %s for specs: %s" % (
                 filtered_services, {
                 filtered_services, {
                     "provider_requirements": provider_requirements,
                     "provider_requirements": provider_requirements,
-                    "region_ids": region_ids, "enabled": enabled}))
+                    "region_sets": region_sets, "enabled": enabled}))
 
 
         return [s[0] for s in filtered_services]
         return [s[0] for s in filtered_services]
-
-    '''
-    def get_workers_for_action(
-            self, ctxt, endpoint_type, provider_type, region_ids=None):
-        """ Returns a list of worker services which would be able to
-        perform the required provider-related action.
-
-        :param endpoint_type: the type of the platform of the endpoint
-        :param provider_type: the type of the required plugin as defined
-                              in constants.PROVIDER_TYPE_*
-        :param region_ids: list of region IDs to carry out the operation in.
-                           Will be carried out ar random if not provided.
-        """
-        # TODO
-        return self._get_all_worker_services(ctxt)
-
-    def get_workers_for_task(
-            self, ctxt, task_type, source_endpoint_type,
-            destination_endpoint_type,
-            source_region_ids=None,
-            destination_region_ids=None):
-        """ Returns a list of worker services which would be
-        able and willing to accomplish the given task.
-        """
-        # TODO
-        return self._get_all_worker_services(ctxt)
-    '''

+ 52 - 1
coriolis/worker/rpc/server.py

@@ -16,7 +16,9 @@ import psutil
 from six.moves import queue
 from six.moves import queue
 
 
 from coriolis.conductor.rpc import client as rpc_conductor_client
 from coriolis.conductor.rpc import client as rpc_conductor_client
+from coriolis.conductor.rpc import utils as conductor_rpc_utils
 from coriolis import constants
 from coriolis import constants
+from coriolis import context
 from coriolis import events
 from coriolis import events
 from coriolis import exception
 from coriolis import exception
 from coriolis.providers import factory as providers_factory
 from coriolis.providers import factory as providers_factory
@@ -63,7 +65,36 @@ class _ConductorProviderEventHandler(events.BaseEventHandler):
 class WorkerServerEndpoint(object):
 class WorkerServerEndpoint(object):
     def __init__(self):
     def __init__(self):
         self._server = utils.get_hostname()
         self._server = utils.get_hostname()
-        self._rpc_conductor_client = rpc_conductor_client.ConductorClient()
+        self._service_registration = self._register_worker_service()
+
+    @property
+    def _rpc_conductor_client(self):
+        # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
+        # oslo_messaging clients as the underlying eventlet thread queues will
+        # be invalidated. Considering this class both serves from a "main
+        # process" as well as forking child processes, it is safest to
+        # re-instantiate the client every time:
+        return rpc_conductor_client.ConductorClient()
+
+    def _register_worker_service(self):
+        host = utils.get_hostname()
+        binary = utils.get_binary_name()
+        dummy_context = context.RequestContext(
+            # TODO(aznashwan): we should ideally have a dedicated
+            # user/pass/tenant just for service registration.
+            # Either way, these values are not used and thus redundant.
+            "coriolis", "admin")
+        status = self.get_service_status(dummy_context)
+        service_registration = (
+            conductor_rpc_utils.check_create_registration_for_service(
+                self._rpc_conductor_client, dummy_context, host, binary,
+                constants.WORKER_MAIN_MESSAGING_TOPIC, enabled=True,
+                providers=status['providers'], specs=status['specs']))
+        LOG.info(
+            "Worker service is successfully registered with the following "
+            "parameters: %s", service_registration)
+        self._service_registration = service_registration
+        return service_registration
 
 
     def _check_remove_dir(self, path):
     def _check_remove_dir(self, path):
         try:
         try:
@@ -196,10 +227,15 @@ class WorkerServerEndpoint(object):
 
 
         self._start_process_with_custom_library_paths(p, extra_library_paths)
         self._start_process_with_custom_library_paths(p, extra_library_paths)
         LOG.info("Task process started: %s", task_id)
         LOG.info("Task process started: %s", task_id)
+        LOG.debug(
+            "Attempting to set task host on Conductor for task '%s'.", task_id)
         try:
         try:
             self._rpc_conductor_client.set_task_host(
             self._rpc_conductor_client.set_task_host(
                 ctxt, task_id, self._server, p.pid)
                 ctxt, task_id, self._server, p.pid)
         except (Exception, KeyboardInterrupt) as ex:
         except (Exception, KeyboardInterrupt) as ex:
+            LOG.debug(
+                "Exception occurred whilst setting host for task '%s'. Error "
+                "was: %s", task_id, utils.get_exception_details())
             # NOTE: because the task error classes are wrapped,
             # NOTE: because the task error classes are wrapped,
             # it's easiest to just check that the messages align:
             # it's easiest to just check that the messages align:
             cancelling_msg = exception.TASK_ALREADY_CANCELLING_EXCEPTION_FMT % {
             cancelling_msg = exception.TASK_ALREADY_CANCELLING_EXCEPTION_FMT % {
@@ -216,10 +252,17 @@ class WorkerServerEndpoint(object):
         p.join()
         p.join()
 
 
         if result is None:
         if result is None:
+            LOG.debug(
+                "No result from process (%s) running task '%s'. "
+                "Presuming task was cancelled.",
+                p.pid, task_id)
             raise exception.TaskProcessCanceledException(
             raise exception.TaskProcessCanceledException(
                 "Task was canceled.")
                 "Task was canceled.")
 
 
         if isinstance(result, str):
         if isinstance(result, str):
+            LOG.debug(
+                "Error message while running task '%s' on process "
+                "with PID '%s': %s", task_id, p.pid, result)
             raise exception.TaskProcessException(result)
             raise exception.TaskProcessException(result)
         return result
         return result
 
 
@@ -238,10 +281,18 @@ class WorkerServerEndpoint(object):
             self._rpc_conductor_client.task_completed(
             self._rpc_conductor_client.task_completed(
                 ctxt, task_id, task_result)
                 ctxt, task_id, task_result)
         except exception.TaskProcessCanceledException as ex:
         except exception.TaskProcessCanceledException as ex:
+            LOG.debug(
+                "Task with ID '%s' appears to have been cancelled. "
+                "Confirming cancellation to Conductor now. Error was: %s",
+                task_id, utils.get_exception_details())
             LOG.exception(ex)
             LOG.exception(ex)
             self._rpc_conductor_client.confirm_task_cancellation(
             self._rpc_conductor_client.confirm_task_cancellation(
                 ctxt, task_id, str(ex))
                 ctxt, task_id, str(ex))
         except Exception as ex:
         except Exception as ex:
+            LOG.debug(
+                "Task with ID '%s' has error'd out. Reporting error to "
+                "Conductor now. Error was: %s",
+                task_id, utils.get_exception_details())
             LOG.exception(ex)
             LOG.exception(ex)
             self._rpc_conductor_client.set_task_error(ctxt, task_id, str(ex))
             self._rpc_conductor_client.set_task_error(ctxt, task_id, str(ex))