Browse Source

Merge pull request #139 from aznashwan/distributed-worker-services

Add Scheduler component and Service/Region management.
Nashwan Azhari 5 years ago
parent
commit
6c06d62905
44 changed files with 2877 additions and 245 deletions
  1. 12 5
      coriolis/api/v1/endpoints.py
  2. 90 0
      coriolis/api/v1/regions.py
  3. 12 0
      coriolis/api/v1/router.py
  4. 93 0
      coriolis/api/v1/services.py
  5. 6 1
      coriolis/api/v1/views/endpoint_view.py
  6. 34 0
      coriolis/api/v1/views/region_view.py
  7. 30 0
      coriolis/api/v1/views/service_view.py
  8. 3 1
      coriolis/cmd/conductor.py
  9. 3 2
      coriolis/cmd/replica_cron.py
  10. 31 0
      coriolis/cmd/scheduler.py
  11. 3 1
      coriolis/cmd/worker.py
  12. 71 7
      coriolis/conductor/rpc/client.py
  13. 582 70
      coriolis/conductor/rpc/server.py
  14. 54 0
      coriolis/conductor/rpc/utils.py
  15. 28 2
      coriolis/constants.py
  16. 476 23
      coriolis/db/api.py
  17. 124 0
      coriolis/db/sqlalchemy/migrate_repo/versions/014_adds_worker_service_regions.py
  18. 112 3
      coriolis/db/sqlalchemy/models.py
  19. 3 2
      coriolis/endpoints/api.py
  20. 39 0
      coriolis/exception.py
  21. 79 0
      coriolis/policies/regions.py
  22. 79 0
      coriolis/policies/services.py
  23. 4 2
      coriolis/policy.py
  24. 1 1
      coriolis/providers/factory.py
  25. 0 0
      coriolis/regions/__init__.py
  26. 27 0
      coriolis/regions/api.py
  27. 3 2
      coriolis/replica_cron/rpc/client.py
  28. 0 0
      coriolis/scheduler/__init__.py
  29. 0 0
      coriolis/scheduler/filters/__init__.py
  30. 22 0
      coriolis/scheduler/filters/base.py
  31. 117 0
      coriolis/scheduler/filters/trivial_filters.py
  32. 0 0
      coriolis/scheduler/rpc/__init__.py
  33. 36 0
      coriolis/scheduler/rpc/client.py
  34. 170 0
      coriolis/scheduler/rpc/server.py
  35. 0 0
      coriolis/services/__init__.py
  36. 29 0
      coriolis/services/api.py
  37. 23 10
      coriolis/tasks/base.py
  38. 15 4
      coriolis/tasks/migration_tasks.py
  39. 47 12
      coriolis/tasks/osmorphing_tasks.py
  40. 317 84
      coriolis/tasks/replica_tasks.py
  41. 7 2
      coriolis/utils.py
  42. 7 5
      coriolis/worker/rpc/client.py
  43. 87 6
      coriolis/worker/rpc/server.py
  44. 1 0
      setup.cfg

+ 12 - 5
coriolis/api/v1/endpoints.py

@@ -40,7 +40,10 @@ class EndpointController(api_wsgi.Controller):
             description = endpoint.get("description")
             description = endpoint.get("description")
             endpoint_type = endpoint["type"]
             endpoint_type = endpoint["type"]
             connection_info = endpoint["connection_info"]
             connection_info = endpoint["connection_info"]
-            return name, endpoint_type, description, connection_info
+            mapped_regions = endpoint.get("mapped_regions", [])
+            return (
+                name, endpoint_type, description, connection_info,
+                mapped_regions)
         except Exception as ex:
         except Exception as ex:
             LOG.exception(ex)
             LOG.exception(ex)
             if hasattr(ex, "message"):
             if hasattr(ex, "message"):
@@ -53,15 +56,19 @@ class EndpointController(api_wsgi.Controller):
         context = req.environ["coriolis.context"]
         context = req.environ["coriolis.context"]
         context.can(endpoint_policies.get_endpoints_policy_label("create"))
         context.can(endpoint_policies.get_endpoints_policy_label("create"))
         (name, endpoint_type, description,
         (name, endpoint_type, description,
-         connection_info) = self._validate_create_body(body)
+         connection_info, mapped_regions) = self._validate_create_body(body)
         return endpoint_view.single(req, self._endpoint_api.create(
         return endpoint_view.single(req, self._endpoint_api.create(
-            context, name, endpoint_type, description, connection_info))
+            context, name, endpoint_type, description, connection_info,
+            mapped_regions))
 
 
     def _validate_update_body(self, body):
     def _validate_update_body(self, body):
         try:
         try:
             endpoint = body["endpoint"]
             endpoint = body["endpoint"]
-            return {k: endpoint[k] for k in endpoint.keys() &
-                    {"name", "description", "connection_info"}}
+            return {
+                k: endpoint[k]
+                for k in endpoint.keys() & {
+                    "name", "description", "connection_info",
+                    "mapped_regions"}}
         except Exception as ex:
         except Exception as ex:
             LOG.exception(ex)
             LOG.exception(ex)
             if hasattr(ex, "message"):
             if hasattr(ex, "message"):

+ 90 - 0
coriolis/api/v1/regions.py

@@ -0,0 +1,90 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from oslo_log import log as logging
+from webob import exc
+
+from coriolis import exception
+from coriolis.api.v1.views import region_view
+from coriolis.api import wsgi as api_wsgi
+from coriolis.policies import regions as region_policies
+from coriolis.regions import api
+
+LOG = logging.getLogger(__name__)
+
+
+class RegionController(api_wsgi.Controller):
+    def __init__(self):
+        self._region_api = api.API()
+        super(RegionController, self).__init__()
+
+    def show(self, req, id):
+        context = req.environ["coriolis.context"]
+        context.can(region_policies.get_regions_policy_label("show"))
+        region = self._region_api.get_region(context, id)
+        if not region:
+            raise exc.HTTPNotFound()
+
+        return region_view.single(req, region)
+
+    def index(self, req):
+        context = req.environ["coriolis.context"]
+        context.can(region_policies.get_regions_policy_label("list"))
+        return region_view.collection(
+            req, self._region_api.get_regions(context))
+
+    def _validate_create_body(self, body):
+        try:
+            region = body["region"]
+            name = region["name"]
+            description = region.get("description", "")
+            enabled = region.get("enabled", True)
+            return name, description, enabled
+        except Exception as ex:
+            LOG.exception(ex)
+            if hasattr(ex, "message"):
+                msg = ex.message
+            else:
+                msg = str(ex)
+            raise exception.InvalidInput(msg)
+
+    def create(self, req, body):
+        context = req.environ["coriolis.context"]
+        context.can(region_policies.get_regions_policy_label("create"))
+        (name, description, enabled) = self._validate_create_body(body)
+        return region_view.single(req, self._region_api.create(
+            context, region_name=name, description=description,
+            enabled=enabled))
+
+    def _validate_update_body(self, body):
+        try:
+            region = body["region"]
+            return {k: region[k] for k in region.keys() &
+                    {"name", "description", "enabled"}}
+        except Exception as ex:
+            LOG.exception(ex)
+            if hasattr(ex, "message"):
+                msg = ex.message
+            else:
+                msg = str(ex)
+            raise exception.InvalidInput(msg)
+
+    def update(self, req, id, body):
+        context = req.environ["coriolis.context"]
+        context.can(region_policies.get_regions_policy_label("update"))
+        updated_values = self._validate_update_body(body)
+        return region_view.single(req, self._region_api.update(
+            req.environ['coriolis.context'], id, updated_values))
+
+    def delete(self, req, id):
+        context = req.environ["coriolis.context"]
+        context.can(region_policies.get_regions_policy_label("delete"))
+        try:
+            self._region_api.delete(req.environ['coriolis.context'], id)
+            raise exc.HTTPNoContent()
+        except exception.NotFound as ex:
+            raise exc.HTTPNotFound(explanation=ex.msg)
+
+
+def create_resource():
+    return api_wsgi.Resource(RegionController())

+ 12 - 0
coriolis/api/v1/router.py

@@ -16,11 +16,13 @@ from coriolis.api.v1 import migration_actions
 from coriolis.api.v1 import migrations
 from coriolis.api.v1 import migrations
 from coriolis.api.v1 import provider_schemas
 from coriolis.api.v1 import provider_schemas
 from coriolis.api.v1 import providers
 from coriolis.api.v1 import providers
+from coriolis.api.v1 import regions
 from coriolis.api.v1 import replica_actions
 from coriolis.api.v1 import replica_actions
 from coriolis.api.v1 import replica_schedules
 from coriolis.api.v1 import replica_schedules
 from coriolis.api.v1 import replica_tasks_execution_actions
 from coriolis.api.v1 import replica_tasks_execution_actions
 from coriolis.api.v1 import replica_tasks_executions
 from coriolis.api.v1 import replica_tasks_executions
 from coriolis.api.v1 import replicas
 from coriolis.api.v1 import replicas
+from coriolis.api.v1 import services
 
 
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
 
 
@@ -43,12 +45,22 @@ class APIRouter(api.APIRouter):
         mapper.resource('provider', 'providers',
         mapper.resource('provider', 'providers',
                         controller=self.resources['providers'])
                         controller=self.resources['providers'])
 
 
+        self.resources['regions'] = regions.create_resource()
+        mapper.resource('region', 'regions',
+                        controller=self.resources['regions'],
+                        collection={'detail': 'GET'})
+
         self.resources['endpoints'] = endpoints.create_resource()
         self.resources['endpoints'] = endpoints.create_resource()
         mapper.resource('endpoint', 'endpoints',
         mapper.resource('endpoint', 'endpoints',
                         controller=self.resources['endpoints'],
                         controller=self.resources['endpoints'],
                         collection={'detail': 'GET'},
                         collection={'detail': 'GET'},
                         member={'action': 'POST'})
                         member={'action': 'POST'})
 
 
+        self.resources['services'] = services.create_resource()
+        mapper.resource('service', 'services',
+                        controller=self.resources['services'],
+                        collection={'detail': 'GET'})
+
         endpoint_actions_resource = endpoint_actions.create_resource()
         endpoint_actions_resource = endpoint_actions.create_resource()
         self.resources['endpoint_actions'] = endpoint_actions_resource
         self.resources['endpoint_actions'] = endpoint_actions_resource
         endpoint_path = '/{project_id}/endpoints/{id}'
         endpoint_path = '/{project_id}/endpoints/{id}'

+ 93 - 0
coriolis/api/v1/services.py

@@ -0,0 +1,93 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from oslo_log import log as logging
+from webob import exc
+
+from coriolis import exception
+from coriolis.api.v1.views import service_view
+from coriolis.api import wsgi as api_wsgi
+from coriolis.policies import services as service_policies
+from coriolis.services import api
+
+LOG = logging.getLogger(__name__)
+
+
+class ServiceController(api_wsgi.Controller):
+    def __init__(self):
+        self._service_api = api.API()
+        super(ServiceController, self).__init__()
+
+    def show(self, req, id):
+        context = req.environ["coriolis.context"]
+        context.can(service_policies.get_services_policy_label("show"))
+        service = self._service_api.get_service(context, id)
+        if not service:
+            raise exc.HTTPNotFound()
+
+        return service_view.single(req, service)
+
+    def index(self, req):
+        context = req.environ["coriolis.context"]
+        context.can(service_policies.get_services_policy_label("list"))
+        return service_view.collection(
+            req, self._service_api.get_services(context))
+
+    def _validate_create_body(self, body):
+        try:
+            service = body["service"]
+            host = service["host"]
+            binary = service["binary"]
+            topic = service["topic"]
+            mapped_regions = service.get('mapped_regions', [])
+            enabled = service.get("enabled", True)
+            return host, binary, topic, mapped_regions, enabled
+        except Exception as ex:
+            LOG.exception(ex)
+            if hasattr(ex, "message"):
+                msg = ex.message
+            else:
+                msg = str(ex)
+            raise exception.InvalidInput(msg)
+
+    def create(self, req, body):
+        context = req.environ["coriolis.context"]
+        context.can(service_policies.get_services_policy_label("create"))
+        (host, binary, topic, mapped_regions, enabled) = (
+            self._validate_create_body(body))
+        return service_view.single(req, self._service_api.create(
+            context, host=host, binary=binary, topic=topic,
+            mapped_regions=mapped_regions, enabled=enabled))
+
+    def _validate_update_body(self, body):
+        try:
+            service = body["service"]
+            return {k: service[k] for k in service.keys() & {
+                "enabled", "mapped_regions"}}
+        except Exception as ex:
+            LOG.exception(ex)
+            if hasattr(ex, "message"):
+                msg = ex.message
+            else:
+                msg = str(ex)
+            raise exception.InvalidInput(msg)
+
+    def update(self, req, id, body):
+        context = req.environ["coriolis.context"]
+        context.can(service_policies.get_services_policy_label("update"))
+        updated_values = self._validate_update_body(body)
+        return service_view.single(req, self._service_api.update(
+            req.environ['coriolis.context'], id, updated_values))
+
+    def delete(self, req, id):
+        context = req.environ["coriolis.context"]
+        context.can(service_policies.get_services_policy_label("delete"))
+        try:
+            self._service_api.delete(req.environ['coriolis.context'], id)
+            raise exc.HTTPNoContent()
+        except exception.NotFound as ex:
+            raise exc.HTTPNotFound(explanation=ex.msg)
+
+
+def create_resource():
+    return api_wsgi.Resource(ServiceController())

+ 6 - 1
coriolis/api/v1/views/endpoint_view.py

@@ -10,8 +10,13 @@ def _format_endpoint(req, endpoint, keys=None):
             return
             return
         yield (key, value)
         yield (key, value)
 
 
-    return dict(itertools.chain.from_iterable(
+    endpoint_dict = dict(itertools.chain.from_iterable(
         transform(k, v) for k, v in endpoint.items()))
         transform(k, v) for k, v in endpoint.items()))
+    mapped_regions = endpoint_dict.get('mapped_regions', [])
+    endpoint_dict['mapped_regions'] = [
+        reg['id'] for reg in mapped_regions]
+
+    return endpoint_dict
 
 
 
 
 def single(req, endpoint):
 def single(req, endpoint):

+ 34 - 0
coriolis/api/v1/views/region_view.py

@@ -0,0 +1,34 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import itertools
+
+
+def _format_region(req, region, keys=None):
+    def transform(key, value):
+        if keys and key not in keys:
+            return
+        yield (key, value)
+
+    region_dict = dict(itertools.chain.from_iterable(
+        transform(k, v) for k, v in region.items()))
+
+    mapped_endpoints = region_dict.get('mapped_endpoints', [])
+    region_dict['mapped_endpoints'] = [
+        endp['id'] for endp in mapped_endpoints]
+
+    mapped_services = region_dict.get('mapped_services', [])
+    region_dict['mapped_services'] = [
+        svc['id'] for svc in mapped_services]
+
+    return region_dict
+
+
+def single(req, region):
+    return {"region": _format_region(req, region)}
+
+
+def collection(req, regions):
+    formatted_regions = [
+        _format_region(req, r) for r in regions]
+    return {'regions': formatted_regions}

+ 30 - 0
coriolis/api/v1/views/service_view.py

@@ -0,0 +1,30 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import itertools
+
+
+def _format_service(req, service, keys=None):
+    def transform(key, value):
+        if keys and key not in keys:
+            return
+        yield (key, value)
+
+    service_dict = dict(itertools.chain.from_iterable(
+        transform(k, v) for k, v in service.items()))
+
+    mapped_regions = service_dict.get('mapped_regions', [])
+    service_dict['mapped_regions'] = [
+        mapping['id'] for mapping in mapped_regions]
+
+    return service_dict
+
+
+def single(req, service):
+    return {"service": _format_service(req, service)}
+
+
+def collection(req, services):
+    formatted_services = [
+        _format_service(req, r) for r in services]
+    return {'services': formatted_services}

+ 3 - 1
coriolis/cmd/conductor.py

@@ -5,6 +5,7 @@ import sys
 
 
 from oslo_config import cfg
 from oslo_config import cfg
 
 
+from coriolis import constants
 from coriolis.conductor.rpc import server as rpc_server
 from coriolis.conductor.rpc import server as rpc_server
 from coriolis import service
 from coriolis import service
 from coriolis import utils
 from coriolis import utils
@@ -19,7 +20,8 @@ def main():
     service.check_locks_dir_empty()
     service.check_locks_dir_empty()
 
 
     server = service.MessagingService(
     server = service.MessagingService(
-        'coriolis_conductor', [rpc_server.ConductorServerEndpoint()],
+        constants.CONDUCTOR_MAIN_MESSAGING_TOPIC,
+        [rpc_server.ConductorServerEndpoint()],
         rpc_server.VERSION)
         rpc_server.VERSION)
     launcher = service.service.launch(
     launcher = service.service.launch(
         CONF, server, workers=server.get_workers_count())
         CONF, server, workers=server.get_workers_count())

+ 3 - 2
coriolis/cmd/replica_cron.py

@@ -5,9 +5,10 @@ import sys
 
 
 from oslo_config import cfg
 from oslo_config import cfg
 
 
-from coriolis.replica_cron.rpc import server as rpc_server
+from coriolis import constants
 from coriolis import service
 from coriolis import service
 from coriolis import utils
 from coriolis import utils
+from coriolis.replica_cron.rpc import server as rpc_server
 
 
 CONF = cfg.CONF
 CONF = cfg.CONF
 
 
@@ -18,7 +19,7 @@ def main():
     utils.setup_logging()
     utils.setup_logging()
 
 
     server = service.MessagingService(
     server = service.MessagingService(
-        'coriolis_replica_cron_worker',
+        constants.REPLICA_CRON_MAIN_MESSAGING_TOPIC,
         [rpc_server.ReplicaCronServerEndpoint()],
         [rpc_server.ReplicaCronServerEndpoint()],
         rpc_server.VERSION, worker_count=1)
         rpc_server.VERSION, worker_count=1)
     launcher = service.service.launch(
     launcher = service.service.launch(

+ 31 - 0
coriolis/cmd/scheduler.py

@@ -0,0 +1,31 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import sys
+
+from oslo_config import cfg
+
+from coriolis import constants
+from coriolis import service
+from coriolis import utils
+from coriolis.scheduler.rpc import server as rpc_server
+
+CONF = cfg.CONF
+
+
+def main():
+    CONF(sys.argv[1:], project='coriolis',
+         version="1.0.0")
+    utils.setup_logging()
+
+    server = service.MessagingService(
+        constants.SCHEDULER_MAIN_MESSAGING_TOPIC,
+        [rpc_server.SchedulerServerEndpoint()],
+        rpc_server.VERSION, worker_count=1)
+    launcher = service.service.launch(
+        CONF, server, workers=server.get_workers_count())
+    launcher.wait()
+
+
+if __name__ == "__main__":
+    main()

+ 3 - 1
coriolis/cmd/worker.py

@@ -5,6 +5,7 @@ import sys
 
 
 from oslo_config import cfg
 from oslo_config import cfg
 
 
+from coriolis import constants
 from coriolis import service
 from coriolis import service
 from coriolis import utils
 from coriolis import utils
 from coriolis.worker.rpc import server as rpc_server
 from coriolis.worker.rpc import server as rpc_server
@@ -18,7 +19,8 @@ def main():
     utils.setup_logging()
     utils.setup_logging()
 
 
     server = service.MessagingService(
     server = service.MessagingService(
-        'coriolis_worker', [rpc_server.WorkerServerEndpoint()],
+        constants.WORKER_MAIN_MESSAGING_TOPIC,
+        [rpc_server.WorkerServerEndpoint()],
         rpc_server.VERSION)
         rpc_server.VERSION)
     launcher = service.service.launch(
     launcher = service.service.launch(
         CONF, server, workers=server.get_workers_count())
         CONF, server, workers=server.get_workers_count())

+ 71 - 7
coriolis/conductor/rpc/client.py

@@ -4,6 +4,7 @@
 from oslo_config import cfg
 from oslo_config import cfg
 import oslo_messaging as messaging
 import oslo_messaging as messaging
 
 
+from coriolis import constants
 from coriolis import rpc
 from coriolis import rpc
 
 
 VERSION = "1.0"
 VERSION = "1.0"
@@ -19,17 +20,19 @@ CONF.register_opts(conductor_opts, 'conductor')
 
 
 
 
 class ConductorClient(object):
 class ConductorClient(object):
-    def __init__(self, timeout=None):
-        target = messaging.Target(topic='coriolis_conductor', version=VERSION)
+    def __init__(self, timeout=None,
+                 topic=constants.CONDUCTOR_MAIN_MESSAGING_TOPIC):
+        target = messaging.Target(topic=topic, version=VERSION)
         if timeout is None:
         if timeout is None:
             timeout = CONF.conductor.conductor_rpc_timeout
             timeout = CONF.conductor.conductor_rpc_timeout
         self._client = rpc.get_client(target, timeout=timeout)
         self._client = rpc.get_client(target, timeout=timeout)
 
 
     def create_endpoint(self, ctxt, name, endpoint_type, description,
     def create_endpoint(self, ctxt, name, endpoint_type, description,
-                        connection_info):
+                        connection_info, mapped_regions):
         return self._client.call(
         return self._client.call(
             ctxt, 'create_endpoint', name=name, endpoint_type=endpoint_type,
             ctxt, 'create_endpoint', name=name, endpoint_type=endpoint_type,
-            description=description, connection_info=connection_info)
+            description=description, connection_info=connection_info,
+            mapped_regions=mapped_regions)
 
 
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
         return self._client.call(
         return self._client.call(
@@ -231,10 +234,13 @@ class ConductorClient(object):
         self._client.call(
         self._client.call(
             ctxt, 'cancel_migration', migration_id=migration_id, force=force)
             ctxt, 'cancel_migration', migration_id=migration_id, force=force)
 
 
-    def set_task_host(self, ctxt, task_id, host, process_id):
+    def set_task_host(self, ctxt, task_id, host):
         self._client.call(
         self._client.call(
-            ctxt, 'set_task_host', task_id=task_id, host=host,
-            process_id=process_id)
+            ctxt, 'set_task_host', task_id=task_id, host=host)
+
+    def set_task_process(self, ctxt, task_id, process_id):
+        self._client.call(
+            ctxt, 'set_task_process', task_id=task_id, process_id=process_id)
 
 
     def task_completed(self, ctxt, task_id, task_result):
     def task_completed(self, ctxt, task_id, task_result):
         self._client.call(
         self._client.call(
@@ -308,3 +314,61 @@ class ConductorClient(object):
 
 
     def get_all_diagnostics(self, ctxt):
     def get_all_diagnostics(self, ctxt):
         return self._client.call(ctxt, 'get_all_diagnostics')
         return self._client.call(ctxt, 'get_all_diagnostics')
+
+    def create_region(
+            self, ctxt, region_name, description="", enabled=True):
+        return self._client.call(
+            ctxt, 'create_region',
+            region_name=region_name,
+            description=description,
+            enabled=enabled)
+
+    def get_regions(self, ctxt):
+        return self._client.call(ctxt, 'get_regions')
+
+    def get_region(self, ctxt, region_id):
+        return self._client.call(
+            ctxt, 'get_region', region_id=region_id)
+
+    def update_region(self, ctxt, region_id, updated_values):
+        return self._client.call(
+            ctxt, 'update_region',
+            region_id=region_id,
+            updated_values=updated_values)
+
+    def delete_region(self, ctxt, region_id):
+        return self._client.call(
+            ctxt, 'delete_region', region_id=region_id)
+
+    def register_service(
+            self, ctxt, host, binary, topic, enabled, mapped_regions,
+            providers=None, specs=None):
+        return self._client.call(
+            ctxt, 'register_service', host=host, binary=binary,
+            topic=topic, enabled=enabled, mapped_regions=mapped_regions,
+            providers=providers, specs=specs)
+
+    def check_service_registered(self, ctxt, host, binary, topic):
+        return self._client.call(
+            ctxt, 'check_service_registered', host=host, binary=binary,
+            topic=topic)
+
+    def refresh_service_status(self, ctxt, service_id):
+        return self._client.call(
+            ctxt, 'refresh_service_status', service_id=service_id)
+
+    def get_services(self, ctxt):
+        return self._client.call(ctxt, 'get_services')
+
+    def get_service(self, ctxt, service_id):
+        return self._client.call(
+            ctxt, 'get_service', service_id=service_id)
+
+    def update_service(self, ctxt, service_id, updated_values):
+        return self._client.call(
+            ctxt, 'update_service', service_id=service_id,
+            updated_values=updated_values)
+
+    def delete_service(self, ctxt, service_id):
+        return self._client.call(
+            ctxt, 'delete_service', service_id=service_id)

+ 582 - 70
coriolis/conductor/rpc/server.py

@@ -3,6 +3,8 @@
 
 
 import copy
 import copy
 import functools
 import functools
+import random
+import time
 import uuid
 import uuid
 
 
 from oslo_concurrency import lockutils
 from oslo_concurrency import lockutils
@@ -17,6 +19,7 @@ from coriolis import exception
 from coriolis import keystone
 from coriolis import keystone
 from coriolis.licensing import client as licensing_client
 from coriolis.licensing import client as licensing_client
 from coriolis.replica_cron.rpc import client as rpc_cron_client
 from coriolis.replica_cron.rpc import client as rpc_cron_client
+from coriolis.scheduler.rpc import client as rpc_scheduler_client
 from coriolis import schemas
 from coriolis import schemas
 from coriolis.tasks import factory as tasks_factory
 from coriolis.tasks import factory as tasks_factory
 from coriolis import utils
 from coriolis import utils
@@ -43,6 +46,14 @@ TASK_DEADLOCK_ERROR_MESSAGE = (
     "A fatal deadlock has occurred. Further debugging is required. "
     "A fatal deadlock has occurred. Further debugging is required. "
     "Please review the Conductor logs and contact support for assistance.")
     "Please review the Conductor logs and contact support for assistance.")
 
 
+RPC_TOPIC_TO_CLIENT_CLASS_MAP = {
+    constants.WORKER_MAIN_MESSAGING_TOPIC: rpc_worker_client.WorkerClient,
+    constants.SCHEDULER_MAIN_MESSAGING_TOPIC: (
+        rpc_scheduler_client.SchedulerClient),
+    constants.REPLICA_CRON_MAIN_MESSAGING_TOPIC: (
+        rpc_cron_client.ReplicaCronClient)
+}
+
 
 
 def endpoint_synchronized(func):
 def endpoint_synchronized(func):
     @functools.wraps(func)
     @functools.wraps(func)
@@ -132,21 +143,131 @@ def tasks_execution_synchronized(func):
     return wrapper
     return wrapper
 
 
 
 
+def region_synchronized(func):
+    @functools.wraps(func)
+    def wrapper(self, ctxt, region_id, *args, **kwargs):
+        @lockutils.synchronized(
+            constants.REGION_LOCK_NAME_FORMAT % region_id,
+            external=True)
+        def inner():
+            return func(self, ctxt, region_id, *args, **kwargs)
+        return inner()
+    return wrapper
+
+
+def service_synchronized(func):
+    @functools.wraps(func)
+    def wrapper(self, ctxt, service_id, *args, **kwargs):
+        @lockutils.synchronized(
+            constants.SERVICE_LOCK_NAME_FORMAT % service_id,
+            external=True)
+        def inner():
+            return func(self, ctxt, service_id, *args, **kwargs)
+        return inner()
+    return wrapper
+
+
 class ConductorServerEndpoint(object):
 class ConductorServerEndpoint(object):
     def __init__(self):
     def __init__(self):
         self._licensing_client = licensing_client.LicensingClient.from_env()
         self._licensing_client = licensing_client.LicensingClient.from_env()
-        self._rpc_worker_client = rpc_worker_client.WorkerClient()
-        self._replica_cron_client = rpc_cron_client.ReplicaCronClient()
+
+    # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
+    # oslo_messaging clients as the underlying eventlet thread queues will
+    # be invalidated. Considering this class both serves from a "main
+    # process" as well as forking child processes, it is safest to
+    # re-instantiate the clients every time:
+    @property
+    def _rpc_worker_client(self):
+        return rpc_worker_client.WorkerClient()
+
+    @property
+    def _scheduler_client(self):
+        return rpc_scheduler_client.SchedulerClient()
+
+    @property
+    def _replica_cron_client(self):
+        return rpc_cron_client.ReplicaCronClient()
 
 
     def get_all_diagnostics(self, ctxt):
     def get_all_diagnostics(self, ctxt):
-        conductor = self.get_diagnostics(ctxt)
-        cron = self._replica_cron_client.get_diagnostics(ctxt)
-        worker = self._rpc_worker_client.get_diagnostics(ctxt)
-        return [
-            conductor,
-            cron,
-            worker,
-        ]
+        diagnostics = [
+            self.get_diagnostics(ctxt),
+            self._replica_cron_client.get_diagnostics(ctxt),
+            self._scheduler_client.get_diagnostics(ctxt)]
+        worker_diagnostics = []
+        for worker_service in self._scheduler_client.get_workers_for_specs(
+                ctxt):
+            worker_rpc = self._get_rpc_client_for_service(worker_service)
+            diagnostics.append(worker_rpc.get_diagnostics(ctxt))
+
+        return diagnostics
+
+    def _get_rpc_client_for_service(self, service, *client_args, **client_kwargs):
+        rpc_client_class = RPC_TOPIC_TO_CLIENT_CLASS_MAP.get(service.topic)
+        if not rpc_client_class:
+            raise exception.NotFound(
+                "No RPC client class for service with topic '%s'." % (
+                    service.topic))
+
+        topic = service.topic
+        if service.topic == constants.WORKER_MAIN_MESSAGING_TOPIC:
+            # NOTE: coriolis.service.MessagingService-type services (such
+            # as the worker), always have a dedicated per-host queue
+            # which can be used to target the service:
+            topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
+                "main_topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
+                "host": service.host})
+
+        return rpc_client_class(*client_args, topic=topic, **client_kwargs)
+
+    def _get_any_worker_service(self, ctxt, random_choice=False, raw_dict=False):
+        services = self._scheduler_client.get_workers_for_specs(ctxt)
+        if not services:
+            raise exception.NoWorkerServiceError()
+        service = services[0]
+        if random_choice:
+            service = random.choice(services)
+        if raw_dict:
+            return service
+        return db_api.get_service(ctxt, service['id'])
+
+    def _get_worker_rpc_for_host(self, host, *client_args, **client_kwargs):
+        rpc_client_class = RPC_TOPIC_TO_CLIENT_CLASS_MAP[
+            constants.WORKER_MAIN_MESSAGING_TOPIC]
+        topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
+            "main_topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
+            "host": host})
+        return rpc_client_class(topic=topic, *client_args, **client_kwargs)
+
+    def _get_worker_service_rpc_for_specs(
+            self, ctxt, provider_requirements=None, region_sets=None,
+            enabled=True, random_choice=False, raise_on_no_matches=True):
+        requirements_str = (
+            "enabled=%s; region_sets=%s; provider_requirements=%s" % (
+                enabled, region_sets, provider_requirements))
+        LOG.info(
+            "Requesting Worker Service from scheduler with the following "
+            "specifications: %s", requirements_str)
+        services = self._scheduler_client.get_workers_for_specs(
+            ctxt, provider_requirements=provider_requirements,
+            region_sets=region_sets, enabled=enabled)
+        if not services:
+            if raise_on_no_matches:
+                raise exception.NoSuitableWorkerServiceError()
+            return None
+        LOG.debug(
+            "Was offered Worker Services with the following IDs for "
+            "requirements '%s': %s",
+            requirements_str, [s["id"] for s in services])
+
+        selected_service = services[0]
+        if random_choice:
+            selected_service = random.choice(services)
+        service = db_api.get_service(ctxt, selected_service["id"])
+
+        LOG.info(
+            "Was offered Worker Service with ID '%s' for requirements: %s",
+            service.id, requirements_str)
+        return self._get_rpc_client_for_service(service)
 
 
     def _check_delete_reservation_for_transfer(self, transfer_action):
     def _check_delete_reservation_for_transfer(self, transfer_action):
         action_id = transfer_action.base_id
         action_id = transfer_action.base_id
@@ -204,8 +325,9 @@ class ConductorServerEndpoint(object):
                 "all reservation licensing checks.", action_id)
                 "all reservation licensing checks.", action_id)
 
 
     def create_endpoint(self, ctxt, name, endpoint_type, description,
     def create_endpoint(self, ctxt, name, endpoint_type, description,
-                        connection_info):
+                        connection_info, mapped_regions=None):
         endpoint = models.Endpoint()
         endpoint = models.Endpoint()
+        endpoint.id = str(uuid.uuid4())
         endpoint.name = name
         endpoint.name = name
         endpoint.type = endpoint_type
         endpoint.type = endpoint_type
         endpoint.description = description
         endpoint.description = description
@@ -213,12 +335,31 @@ class ConductorServerEndpoint(object):
 
 
         db_api.add_endpoint(ctxt, endpoint)
         db_api.add_endpoint(ctxt, endpoint)
         LOG.info("Endpoint created: %s", endpoint.id)
         LOG.info("Endpoint created: %s", endpoint.id)
+
+        # add region associations:
+        if mapped_regions:
+            try:
+                db_api.update_endpoint(
+                    ctxt, endpoint.id, {
+                        "mapped_regions": mapped_regions})
+            except Exception as ex:
+                LOG.warn(
+                    "Error adding region mappings during new endpoint creation "
+                    "(name: %s), cleaning up endpoint and all created "
+                    "mappings for regions: %s", endpoint.name, mapped_regions)
+                db_api.delete_endpoint(ctxt, endpoint.id)
+                raise
+
         return self.get_endpoint(ctxt, endpoint.id)
         return self.get_endpoint(ctxt, endpoint.id)
 
 
+    @endpoint_synchronized
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
+        LOG.info(
+            "Attempting to update endpoint '%s' with payload: %s",
+            endpoint_id, updated_values)
         db_api.update_endpoint(ctxt, endpoint_id, updated_values)
         db_api.update_endpoint(ctxt, endpoint_id, updated_values)
         LOG.info("Endpoint updated: %s", endpoint_id)
         LOG.info("Endpoint updated: %s", endpoint_id)
-        return self.get_endpoint(ctxt, endpoint_id)
+        return db_api.get_endpoint(ctxt, endpoint_id)
 
 
     def get_endpoints(self, ctxt):
     def get_endpoints(self, ctxt):
         return db_api.get_endpoints(ctxt)
         return db_api.get_endpoints(ctxt)
@@ -244,7 +385,12 @@ class ConductorServerEndpoint(object):
                                marker, limit, instance_name_pattern):
                                marker, limit, instance_name_pattern):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_instances(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
+        return worker_rpc.get_endpoint_instances(
             ctxt, endpoint.type, endpoint.connection_info,
             ctxt, endpoint.type, endpoint.connection_info,
             source_environment, marker, limit, instance_name_pattern)
             source_environment, marker, limit, instance_name_pattern)
 
 
@@ -252,7 +398,13 @@ class ConductorServerEndpoint(object):
             self, ctxt, endpoint_id, source_environment, instance_name):
             self, ctxt, endpoint_id, source_environment, instance_name):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_instance(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_INSTANCES]})
+
+        return worker_rpc.get_endpoint_instance(
             ctxt, endpoint.type, endpoint.connection_info,
             ctxt, endpoint.type, endpoint.connection_info,
             source_environment, instance_name)
             source_environment, instance_name)
 
 
@@ -260,50 +412,102 @@ class ConductorServerEndpoint(object):
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_source_options(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [
+                    constants.PROVIDER_TYPE_SOURCE_ENDPOINT_OPTIONS]})
+
+        return worker_rpc.get_endpoint_source_options(
             ctxt, endpoint.type, endpoint.connection_info, env, option_names)
             ctxt, endpoint.type, endpoint.connection_info, env, option_names)
 
 
     def get_endpoint_destination_options(
     def get_endpoint_destination_options(
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_destination_options(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [
+                    constants.PROVIDER_TYPE_DESTINATION_ENDPOINT_OPTIONS]})
+        return worker_rpc.get_endpoint_destination_options(
             ctxt, endpoint.type, endpoint.connection_info, env, option_names)
             ctxt, endpoint.type, endpoint.connection_info, env, option_names)
 
 
     def get_endpoint_networks(self, ctxt, endpoint_id, env):
     def get_endpoint_networks(self, ctxt, endpoint_id, env):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_networks(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_NETWORKS]})
+
+        return worker_rpc.get_endpoint_networks(
             ctxt, endpoint.type, endpoint.connection_info, env)
             ctxt, endpoint.type, endpoint.connection_info, env)
 
 
     def get_endpoint_storage(self, ctxt, endpoint_id, env):
     def get_endpoint_storage(self, ctxt, endpoint_id, env):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
 
 
-        return self._rpc_worker_client.get_endpoint_storage(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT_STORAGE]})
+
+        return worker_rpc.get_endpoint_storage(
             ctxt, endpoint.type, endpoint.connection_info, env)
             ctxt, endpoint.type, endpoint.connection_info, env)
 
 
     def validate_endpoint_connection(self, ctxt, endpoint_id):
     def validate_endpoint_connection(self, ctxt, endpoint_id):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
-        return self._rpc_worker_client.validate_endpoint_connection(
+
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
+
+        return worker_rpc.validate_endpoint_connection(
             ctxt, endpoint.type, endpoint.connection_info)
             ctxt, endpoint.type, endpoint.connection_info)
 
 
     def validate_endpoint_target_environment(
     def validate_endpoint_target_environment(
             self, ctxt, endpoint_id, target_env):
             self, ctxt, endpoint_id, target_env):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
-        return self._rpc_worker_client.validate_endpoint_target_environment(
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
+
+        return worker_rpc.validate_endpoint_target_environment(
             ctxt, endpoint.type, target_env)
             ctxt, endpoint.type, target_env)
 
 
     def validate_endpoint_source_environment(
     def validate_endpoint_source_environment(
             self, ctxt, endpoint_id, source_env):
             self, ctxt, endpoint_id, source_env):
         endpoint = self.get_endpoint(ctxt, endpoint_id)
         endpoint = self.get_endpoint(ctxt, endpoint_id)
-        return self._rpc_worker_client.validate_endpoint_source_environment(
+
+        worker_rpc = self._get_worker_service_rpc_for_specs(
+            ctxt, enabled=True,
+            region_sets=[[reg.id for reg in endpoint.mapped_regions]],
+            provider_requirements={
+                endpoint.type: [constants.PROVIDER_TYPE_ENDPOINT]})
+
+        return worker_rpc.validate_endpoint_source_environment(
             ctxt, endpoint.type, source_env)
             ctxt, endpoint.type, source_env)
 
 
     def get_available_providers(self, ctxt):
     def get_available_providers(self, ctxt):
-        return self._rpc_worker_client.get_available_providers(ctxt)
+        # TODO(aznashwan): merge list of all providers from all
+        # worker services:
+        worker_rpc = self._get_rpc_client_for_service(
+            self._get_any_worker_service(ctxt))
+        return worker_rpc.get_available_providers(ctxt)
 
 
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
-        return self._rpc_worker_client.get_provider_schemas(
+        # TODO(aznashwan): merge or version/namespace schemas for each worker?
+        worker_rpc = self._get_rpc_client_for_service(
+            self._get_any_worker_service(ctxt))
+        return worker_rpc.get_provider_schemas(
             ctxt, platform_name, provider_type)
             ctxt, platform_name, provider_type)
 
 
     @staticmethod
     @staticmethod
@@ -360,7 +564,85 @@ class ConductorServerEndpoint(object):
             "target_environment": action.destination_environment
             "target_environment": action.destination_environment
         }
         }
 
 
-    def _begin_tasks(self, ctxt, execution, task_info={}):
+    def _get_worker_service_rpc_for_task(
+            self, ctxt, task, origin_endpoint, destination_endpoint,
+            retry_count=5, retry_period=2):
+        LOG.debug(
+            "Compiling required Worker Service specs for task with "
+            "ID '%s' (type '%s') from endpoints '%s' to '%s'",
+            task.id, task.task_type, origin_endpoint.id,
+            destination_endpoint.id)
+        task_cls = tasks_factory.get_task_runner_class(task.task_type)
+
+        # determine required Coriolis regions based on the endpoints:
+        required_region_sets = []
+        origin_endpoint_region_ids = [
+            r.id for r in origin_endpoint.mapped_regions]
+        destination_endpoint_region_ids = [
+            r.id for r in destination_endpoint.mapped_regions]
+
+        required_platform = task_cls.get_required_platform()
+        if required_platform in (
+                constants.TASK_PLATFORM_SOURCE,
+                constants.TASK_PLATFORM_BILATERAL):
+            required_region_sets.append(origin_endpoint_region_ids)
+        if required_platform in (
+                constants.TASK_PLATFORM_DESTINATION,
+                constants.TASK_PLATFORM_BILATERAL):
+            required_region_sets.append(destination_endpoint_region_ids)
+
+        # determine provider requirements:
+        provider_requirements = {}
+        required_provider_types = task_cls.get_required_provider_types()
+        if constants.PROVIDER_PLATFORM_SOURCE in required_provider_types:
+            provider_requirements[origin_endpoint.type] = (
+                required_provider_types[
+                    constants.PROVIDER_PLATFORM_SOURCE])
+        if constants.PROVIDER_PLATFORM_DESTINATION in required_provider_types:
+            provider_requirements[destination_endpoint.type] = (
+                required_provider_types[
+                    constants.PROVIDER_PLATFORM_DESTINATION])
+
+        worker_rpc = None
+        for i in range(retry_count):
+            try:
+                LOG.debug(
+                    "Requesting Worker Service for task with ID '%s' (type "
+                    "'%s') from endpoints '%s' to '%s'", task.id,
+                    task.task_type, origin_endpoint.id,
+                    destination_endpoint.id)
+                worker_rpc = self._get_worker_service_rpc_for_specs(
+                    ctxt, provider_requirements=provider_requirements,
+                    region_sets=required_region_sets, enabled=True)
+                LOG.debug(
+                    "Scheduler has granted Worker Service for task with ID "
+                    "'%s' (type '%s') from endpoints '%s' to '%s'",
+                    task.id, task.task_type, origin_endpoint.id,
+                    destination_endpoint.id)
+                return worker_rpc
+            except Exception as ex:
+                LOG.warn(
+                    "Failed to schedule task with ID '%s' (attempt %d/%d). "
+                    "Waiting %d seconds and then retrying. Error was: %s",
+                    task.id, i+1, retry_count, retry_period,
+                    utils.get_exception_details())
+                time.sleep(retry_period)
+
+        message = (
+            "Failed to schedule task %s after %d tries. This may indicate that"
+            " there are no Coriolis Worker services able to perform the task "
+            "on the platforms and in the Coriolis Regions required by the "
+            "selected source/destination Coriolis Endpoints. Please review"
+            " the Conductor and Scheduler logs for more exact details." % (
+                task.id, retry_count))
+        db_api.set_task_status(
+            ctxt, task.id, constants.TASK_STATUS_FAILED_TO_SCHEDULE,
+            exception_details=message)
+        raise exception.NoSuitableWorkerServiceError(message)
+
+    def _begin_tasks(
+            self, ctxt, execution, task_info={},
+            scheduling_retry_count=5, scheduling_retry_period=2):
         """ Starts all non-error-only tasks which have no depencies. """
         """ Starts all non-error-only tasks which have no depencies. """
         if not ctxt.trust_id:
         if not ctxt.trust_id:
             keystone.create_trust(ctxt)
             keystone.create_trust(ctxt)
@@ -368,6 +650,10 @@ class ConductorServerEndpoint(object):
 
 
         origin = self._get_task_origin(ctxt, execution.action)
         origin = self._get_task_origin(ctxt, execution.action)
         destination = self._get_task_destination(ctxt, execution.action)
         destination = self._get_task_destination(ctxt, execution.action)
+        origin_endpoint = db_api.get_endpoint(
+            ctxt, execution.action.origin_endpoint_id)
+        destination_endpoint = db_api.get_endpoint(
+            ctxt, execution.action.destination_endpoint_id)
 
 
         newly_started_tasks = []
         newly_started_tasks = []
         for task in execution.tasks:
         for task in execution.tasks:
@@ -378,14 +664,27 @@ class ConductorServerEndpoint(object):
                     task.id, execution.id)
                     task.id, execution.id)
                 db_api.set_task_status(
                 db_api.set_task_status(
                     ctxt, task.id, constants.TASK_STATUS_PENDING)
                     ctxt, task.id, constants.TASK_STATUS_PENDING)
-                self._rpc_worker_client.begin_task(
-                    ctxt, server=None,
-                    task_id=task.id,
-                    task_type=task.task_type,
-                    origin=origin,
-                    destination=destination,
-                    instance=task.instance,
-                    task_info=task_info.get(task.instance, {}))
+                try:
+                    worker_rpc = self._get_worker_service_rpc_for_task(
+                        ctxt, task, origin_endpoint, destination_endpoint,
+                        retry_count=scheduling_retry_count,
+                        retry_period=scheduling_retry_period)
+                    worker_rpc.begin_task(
+                        ctxt, server=None,
+                        task_id=task.id,
+                        task_type=task.task_type,
+                        origin=origin,
+                        destination=destination,
+                        instance=task.instance,
+                        task_info=task_info.get(task.instance, {}))
+                except Exception as ex:
+                    LOG.warn(
+                        "Error occured while starting new task '%s'. "
+                        "Cancelling execution '%s'. Error was: %s",
+                        task.id, execution.id, utils.get_exception_details())
+                    self._cancel_tasks_execution(
+                        ctxt, execution, requery=True)
+                    raise
                 newly_started_tasks.append(task.id)
                 newly_started_tasks.append(task.id)
 
 
         # NOTE: this should never happen if _check_execution_tasks_sanity
         # NOTE: this should never happen if _check_execution_tasks_sanity
@@ -409,9 +708,9 @@ class ConductorServerEndpoint(object):
             for instance in all_instances_in_tasks}
             for instance in all_instances_in_tasks}
 
 
         def _check_task_cls_param_requirements(task, instance_task_info_keys):
         def _check_task_cls_param_requirements(task, instance_task_info_keys):
-            task_cls = tasks_factory.get_task_runner_class(task.task_type)()
+            task_cls = tasks_factory.get_task_runner_class(task.task_type)
             missing_params = [
             missing_params = [
-                p for p in task_cls.required_task_info_properties
+                p for p in task_cls.get_required_task_info_properties()
                 if p not in instance_task_info_keys]
                 if p not in instance_task_info_keys]
             if missing_params:
             if missing_params:
                 raise exception.CoriolisException(
                 raise exception.CoriolisException(
@@ -420,7 +719,7 @@ class ConductorServerEndpoint(object):
                     "type '%s': %s" % (
                     "type '%s': %s" % (
                         task.instance, task.id, task.task_type,
                         task.instance, task.id, task.task_type,
                         missing_params))
                         missing_params))
-            return task_cls.returned_task_info_properties
+            return task_cls.get_returned_task_info_properties()
 
 
         for instance, instance_tasks in instances_tasks_mapping.items():
         for instance, instance_tasks in instances_tasks_mapping.items():
             task_info_keys = set(initial_task_info.get(
             task_info_keys = set(initial_task_info.get(
@@ -746,7 +1045,7 @@ class ConductorServerEndpoint(object):
 
 
         self._check_execution_tasks_sanity(execution, replica.info)
         self._check_execution_tasks_sanity(execution, replica.info)
 
 
-        # update the action info for all of the Replicas' instnaces:
+        # update the action info for all of the Replicas' instances:
         for instance in replica.instances:
         for instance in replica.instances:
             db_api.update_transfer_action_info_for_instance(
             db_api.update_transfer_action_info_for_instance(
                 ctxt, replica.id, instance, replica.info[instance])
                 ctxt, replica.id, instance, replica.info[instance])
@@ -758,6 +1057,12 @@ class ConductorServerEndpoint(object):
 
 
     @staticmethod
     @staticmethod
     def _check_endpoints(ctxt, origin_endpoint, destination_endpoint):
     def _check_endpoints(ctxt, origin_endpoint, destination_endpoint):
+        if origin_endpoint.id == destination_endpoint.id:
+            raise exception.SameDestination(
+                "The origin and destination endpoints cannot be the same. "
+                "If you need to perform operations across two areas of "
+                "the same platform (ex: migrating across public cloud regions)"
+                ", please create two separate endpoints.")
         # TODO(alexpilotti): check Barbican secrets content as well
         # TODO(alexpilotti): check Barbican secrets content as well
         if (origin_endpoint.connection_info ==
         if (origin_endpoint.connection_info ==
                 destination_endpoint.connection_info):
                 destination_endpoint.connection_info):
@@ -773,8 +1078,8 @@ class ConductorServerEndpoint(object):
 
 
         replica = models.Replica()
         replica = models.Replica()
         replica.id = str(uuid.uuid4())
         replica.id = str(uuid.uuid4())
-        replica.origin_endpoint = origin_endpoint
-        replica.destination_endpoint = destination_endpoint
+        replica.origin_endpoint_id = origin_endpoint_id
+        replica.destination_endpoint_id = destination_endpoint_id
         replica.destination_environment = destination_environment
         replica.destination_environment = destination_environment
         replica.source_environment = source_environment
         replica.source_environment = source_environment
         replica.instances = instances
         replica.instances = instances
@@ -1036,8 +1341,8 @@ class ConductorServerEndpoint(object):
 
 
         migration = models.Migration()
         migration = models.Migration()
         migration.id = str(uuid.uuid4())
         migration.id = str(uuid.uuid4())
-        migration.origin_endpoint = origin_endpoint
-        migration.destination_endpoint = destination_endpoint
+        migration.origin_endpoint_id = origin_endpoint_id
+        migration.destination_endpoint_id = destination_endpoint_id
         migration.destination_environment = destination_environment
         migration.destination_environment = destination_environment
         migration.source_environment = source_environment
         migration.source_environment = source_environment
         migration.network_map = network_map
         migration.network_map = network_map
@@ -1318,8 +1623,24 @@ class ConductorServerEndpoint(object):
                 continue
                 continue
 
 
             if task.status in (
             if task.status in (
-                    constants.TASK_STATUS_RUNNING,
-                    constants.TASK_STATUS_PENDING):
+                    constants.TASK_STATUS_PENDING,
+                    constants.TASK_STATUS_STARTING):
+                # any PENDING/STARTING tasks means that they did not have a
+                # host assigned to them yet, and presuming the host does not
+                # start executing the task until it marks itself as the runner,
+                # we can just mark the task as cancelled:
+                LOG.debug(
+                    "Setting currently '%s' task '%s' to '%s' as part of the "
+                    "cancellation of execution '%s'",
+                    task.status, task.id,
+                    constants.TASK_STATUS_UNSCHEDULED, execution.id)
+                db_api.set_task_status(
+                    ctxt, task.id, constants.TASK_STATUS_UNSCHEDULED,
+                    exception_details=(
+                        "This task was already pending execution but was "
+                        "unscheduled during the cancellation of the parent "
+                        "tasks execution."))
+            elif task.status == constants.TASK_STATUS_RUNNING:
                 # cancel any currently running/pending non-error tasks:
                 # cancel any currently running/pending non-error tasks:
                 if not task.on_error:
                 if not task.on_error:
                     LOG.debug(
                     LOG.debug(
@@ -1328,7 +1649,8 @@ class ConductorServerEndpoint(object):
                         task.status, task.id, execution.id)
                         task.status, task.id, execution.id)
                     db_api.set_task_status(
                     db_api.set_task_status(
                         ctxt, task.id, constants.TASK_STATUS_CANCELLING)
                         ctxt, task.id, constants.TASK_STATUS_CANCELLING)
-                    self._rpc_worker_client.cancel_task(
+                    worker_rpc = self._get_worker_rpc_for_host(task.host)
+                    worker_rpc.cancel_task(
                         ctxt, task.host, task.id, task.process_id, force)
                         ctxt, task.host, task.id, task.process_id, force)
                 # let any on-error tasks run to completion but mark
                 # let any on-error tasks run to completion but mark
                 # them as CANCELLING_AFTER_COMPLETION so they will
                 # them as CANCELLING_AFTER_COMPLETION so they will
@@ -1361,7 +1683,8 @@ class ConductorServerEndpoint(object):
                     "execution '%s'",
                     "execution '%s'",
                     task.id, task.status, task.on_error, execution.id)
                     task.id, task.status, task.on_error, execution.id)
 
 
-        started_tasks = self._advance_execution_state(ctxt, execution)
+        started_tasks = self._advance_execution_state(
+            ctxt, execution, requery=True)
         if started_tasks:
         if started_tasks:
             LOG.info(
             LOG.info(
                 "The following tasks were started after state advancement "
                 "The following tasks were started after state advancement "
@@ -1382,11 +1705,11 @@ class ConductorServerEndpoint(object):
             keystone.delete_trust(ctxt)
             keystone.delete_trust(ctxt)
 
 
     @parent_tasks_execution_synchronized
     @parent_tasks_execution_synchronized
-    def set_task_host(self, ctxt, task_id, host, process_id):
-        """ Saves the ID of the worker host which has accepted and started
-        the task to the DB and marks the task as 'RUNNING'. """
+    def set_task_host(self, ctxt, task_id, host):
+        """ Saves the ID of the worker host which has accepted
+        the task to the DB and marks the task as STARTING. """
         task = db_api.get_task(ctxt, task_id)
         task = db_api.get_task(ctxt, task_id)
-        new_status = constants.TASK_STATUS_RUNNING
+        new_status = constants.TASK_STATUS_STARTING
         exception_details = None
         exception_details = None
         if task.status == constants.TASK_STATUS_CANCELLING:
         if task.status == constants.TASK_STATUS_CANCELLING:
             raise exception.TaskIsCancelling(task_id=task_id)
             raise exception.TaskIsCancelling(task_id=task_id)
@@ -1409,10 +1732,45 @@ class ConductorServerEndpoint(object):
                 "Task with ID '%s' is in '%s' status instead of the "
                 "Task with ID '%s' is in '%s' status instead of the "
                 "expected '%s' required for it to have a task host set." % (
                 "expected '%s' required for it to have a task host set." % (
                     task_id, task.status, constants.TASK_STATUS_PENDING))
                     task_id, task.status, constants.TASK_STATUS_PENDING))
-        db_api.set_task_host(ctxt, task_id, host, process_id)
+        LOG.info(
+            "Setting host for task with ID '%s' to '%s'", task_id, host)
+        db_api.set_task_host_properties(ctxt, task_id, host=host)
         db_api.set_task_status(
         db_api.set_task_status(
             ctxt, task_id, new_status,
             ctxt, task_id, new_status,
             exception_details=exception_details)
             exception_details=exception_details)
+        LOG.info(
+            "Successfully set host for task with ID '%s' to '%s'",
+            task_id, host)
+
+    @parent_tasks_execution_synchronized
+    def set_task_process(self, ctxt, task_id, process_id):
+        """ Sets the ID of the Worker-side process for the given task,
+        and marks the task as actually 'RUNNING'. """
+        task = db_api.get_task(ctxt, task_id)
+        if not task.host:
+            raise exception.InvalidTaskState(
+                "Task with ID '%s' (current status '%s') has no host set "
+                "for it. Cannot set host process." % (
+                    task_id, task.status))
+        acceptable_statuses = [
+            constants.TASK_STATUS_STARTING,
+            constants.TASK_STATUS_CANCELLING_AFTER_COMPLETION]
+        if task.status not in acceptable_statuses:
+            raise exception.InvalidTaskState(
+                "Task with ID '%s' is in '%s' status instead of the "
+                "expected statuses (%s) required for it to have a task "
+                "process set." % (
+                    task_id, task.status, acceptable_statuses))
+
+        LOG.info(
+            "Setting process '%s' (host %s) for task '%s' and transitioning "
+            "it from status '%s' to '%s'", process_id, task.host, task_id,
+            task.status, constants.TASK_STATUS_RUNNING)
+        db_api.set_task_host_properties(ctxt, task_id, process_id=process_id)
+        db_api.set_task_status(ctxt, task_id, constants.TASK_STATUS_RUNNING)
+        LOG.info(
+            "Successfully set task process for task with ID '%s' to '%s'",
+            task_id, process_id)
 
 
     def _check_clean_execution_deadlock(
     def _check_clean_execution_deadlock(
             self, ctxt, execution, task_statuses=None, requery=True):
             self, ctxt, execution, task_statuses=None, requery=True):
@@ -1464,7 +1822,7 @@ class ConductorServerEndpoint(object):
 
 
     def _get_execution_status(self, ctxt, execution, requery=False):
     def _get_execution_status(self, ctxt, execution, requery=False):
         """ Returns the global status of an execution.
         """ Returns the global status of an execution.
-        RUNNING - at least one task is RUNNING, PENDING or CANCELLING
+        RUNNING - at least one task is RUNNING, STARTING, PENDING or CANCELLING
         COMPLETED - all non-error-only tasks are COMPLETED
         COMPLETED - all non-error-only tasks are COMPLETED
         CANCELED - no more RUNNING/PENDING/SCHEDULED tasks but some CANCELED
         CANCELED - no more RUNNING/PENDING/SCHEDULED tasks but some CANCELED
         CANCELIING - at least one task in CANCELLING status
         CANCELIING - at least one task in CANCELLING status
@@ -1485,7 +1843,9 @@ class ConductorServerEndpoint(object):
                 is_running = True
                 is_running = True
             if task.status in constants.CANCELED_TASK_STATUSES:
             if task.status in constants.CANCELED_TASK_STATUSES:
                 is_canceled = True
                 is_canceled = True
-            if task.status == constants.TASK_STATUS_ERROR:
+            if task.status in (
+                    constants.TASK_STATUS_ERROR,
+                    constants.TASK_STATUS_FAILED_TO_SCHEDULE):
                 is_errord = True
                 is_errord = True
             if task.status in (
             if task.status in (
                     constants.TASK_STATUS_CANCELLING,
                     constants.TASK_STATUS_CANCELLING,
@@ -1577,6 +1937,10 @@ class ConductorServerEndpoint(object):
         origin = self._get_task_origin(ctxt, execution.action)
         origin = self._get_task_origin(ctxt, execution.action)
         destination = self._get_task_destination(ctxt, execution.action)
         destination = self._get_task_destination(ctxt, execution.action)
         action = db_api.get_action(ctxt, execution.action_id)
         action = db_api.get_action(ctxt, execution.action_id)
+        origin_endpoint = db_api.get_endpoint(
+            ctxt, execution.action.origin_endpoint_id)
+        destination_endpoint = db_api.get_endpoint(
+            ctxt, execution.action.destination_endpoint_id)
 
 
         started_tasks = []
         started_tasks = []
 
 
@@ -1594,15 +1958,31 @@ class ConductorServerEndpoint(object):
                 task_info = action.info[task.instance]
                 task_info = action.info[task.instance]
             db_api.set_task_status(
             db_api.set_task_status(
                 ctxt, task.id, constants.TASK_STATUS_PENDING)
                 ctxt, task.id, constants.TASK_STATUS_PENDING)
-            self._rpc_worker_client.begin_task(
-                ctxt, server=None,
-                task_id=task.id,
-                task_type=task.task_type,
-                origin=origin,
-                destination=destination,
-                instance=task.instance,
-                task_info=task_info)
-            started_tasks.append(task.id)
+            try:
+                worker_rpc = self._get_worker_service_rpc_for_task(
+                    ctxt, task, origin_endpoint, destination_endpoint)
+                worker_rpc.begin_task(
+                    ctxt, server=None,
+                    task_id=task.id,
+                    task_type=task.task_type,
+                    origin=origin,
+                    destination=destination,
+                    instance=task.instance,
+                    task_info=task_info)
+                LOG.debug(
+                    "Successfully started task with ID '%s' (type '%s') "
+                    "for execution '%s'", task.id, task.task_type,
+                    execution.id)
+                started_tasks.append(task.id)
+                return constants.TASK_STATUS_PENDING
+            except Exception as ex:
+                LOG.warn(
+                    "Error occured while starting new task '%s'. "
+                    "Cancelling execution '%s'. Error was: %s",
+                    task.id, execution.id, utils.get_exception_details())
+                self._cancel_tasks_execution(
+                    ctxt, execution, requery=True)
+                raise
 
 
         # aggregate all tasks and statuses:
         # aggregate all tasks and statuses:
         task_statuses = {}
         task_statuses = {}
@@ -1622,7 +2002,7 @@ class ConductorServerEndpoint(object):
         LOG.debug(
         LOG.debug(
             "All task statuses before execution '%s' lifecycle iteration "
             "All task statuses before execution '%s' lifecycle iteration "
             "(for tasks of instance '%s'): %s",
             "(for tasks of instance '%s'): %s",
-            instance, execution.id, task_statuses)
+            execution.id, instance, task_statuses)
 
 
         # NOTE: the tasks are saved in a random order in the DB, which
         # NOTE: the tasks are saved in a random order in the DB, which
         # complicates the processing logic so we just pre-sort:
         # complicates the processing logic so we just pre-sort:
@@ -1634,8 +2014,7 @@ class ConductorServerEndpoint(object):
                 if not task_deps[task.id]:
                 if not task_deps[task.id]:
                     LOG.info(
                     LOG.info(
                         "Starting depency-less task '%s'", task.id)
                         "Starting depency-less task '%s'", task.id)
-                    _start_task(task)
-                    task_statuses[task.id] = constants.TASK_STATUS_PENDING
+                    task_statuses[task.id] = _start_task(task)
                     continue
                     continue
 
 
                 parent_task_statuses = {
                 parent_task_statuses = {
@@ -1676,9 +2055,7 @@ class ConductorServerEndpoint(object):
                                 "Starting task '%s' as all dependencies have "
                                 "Starting task '%s' as all dependencies have "
                                 "completed successfully: %s",
                                 "completed successfully: %s",
                                 task.id, parent_task_statuses)
                                 task.id, parent_task_statuses)
-                            _start_task(task)
-                            task_statuses[task.id] = (
-                                constants.TASK_STATUS_PENDING)
+                            task_statuses[task.id] = _start_task(task)
                         else:
                         else:
                             # it means one/more parents error'd/unscheduled
                             # it means one/more parents error'd/unscheduled
                             # so we mark this task as unscheduled:
                             # so we mark this task as unscheduled:
@@ -1712,9 +2089,7 @@ class ConductorServerEndpoint(object):
                                 "non-error parent (%s) was completed: %s",
                                 "non-error parent (%s) was completed: %s",
                                 task.id, list(non_error_parents.keys()),
                                 task.id, list(non_error_parents.keys()),
                                 parent_task_statuses)
                                 parent_task_statuses)
-                            _start_task(task)
-                            task_statuses[task.id] = (
-                                constants.TASK_STATUS_PENDING)
+                            task_statuses[task.id] = _start_task(task)
                         else:
                         else:
                             LOG.info(
                             LOG.info(
                                 "Unscheduling on-error task '%s' as none of "
                                 "Unscheduling on-error task '%s' as none of "
@@ -2007,7 +2382,7 @@ class ConductorServerEndpoint(object):
                 ctxt, task, execution, updated_task_info)
                 ctxt, task, execution, updated_task_info)
 
 
             newly_started_tasks = self._advance_execution_state(
             newly_started_tasks = self._advance_execution_state(
-                ctxt, execution, instance=task.instance)
+                ctxt, execution, instance=task.instance, requery=False)
             if newly_started_tasks:
             if newly_started_tasks:
                 LOG.info(
                 LOG.info(
                     "The following tasks were started for execution '%s' "
                     "The following tasks were started for execution '%s' "
@@ -2377,3 +2752,140 @@ class ConductorServerEndpoint(object):
 
 
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
         return utils.get_diagnostics_info()
         return utils.get_diagnostics_info()
+
+    def create_region(self, ctxt, region_name, description="", enabled=True):
+        region = models.Region()
+        region.id = str(uuid.uuid4())
+        region.name = region_name
+        region.description = description
+        region.enabled = enabled
+        db_api.add_region(ctxt, region)
+        return self.get_region(ctxt, region.id)
+
+    def get_regions(self, ctxt):
+        return db_api.get_regions(ctxt)
+
+    @region_synchronized
+    def get_region(self, ctxt, region_id):
+        region = db_api.get_region(ctxt, region_id)
+        if not region:
+            raise exception.NotFound(
+                "Region with ID '%s' not found." % region_id)
+        return region
+
+    @region_synchronized
+    def update_region(self, ctxt, region_id, updated_values):
+        LOG.info(
+            "Attempting to update region '%s' with payload: %s",
+            region_id, updated_values)
+        db_api.update_region(ctxt, region_id, updated_values)
+        LOG.info("Region '%s' successfully updated", region_id)
+        return db_api.get_region(ctxt, region_id)
+
+    @region_synchronized
+    def delete_region(self, ctxt, region_id):
+        # TODO(aznashwan): add checks for endpoints/services
+        # associated to the region before deletion:
+        db_api.delete_region(ctxt, region_id)
+
+    def register_service(
+            self, ctxt, host, binary, topic, enabled, mapped_regions=None,
+            providers=None, specs=None):
+        service = db_api.find_service(ctxt, host, binary, topic=topic)
+        if service:
+            raise exception.Conflict(
+                "A Service with the specified parameters (host %s, binary %s, "
+                "topic %s) has already been registered under ID: %s" % (
+                    host, binary, topic, service.id))
+
+        service = models.Service()
+        service.id = str(uuid.uuid4())
+        service.host = host
+        service.binary = binary
+        service.enabled = enabled
+        service.topic = topic
+        service.status = constants.SERVICE_STATUS_UP
+
+        if None in (providers, specs):
+            worker_rpc = self._get_rpc_client_for_service(service)
+            status = worker_rpc.get_service_status(ctxt)
+
+            service.providers = status["providers"]
+            service.specs = status["specs"]
+        else:
+            service.providers = providers
+            service.specs = specs
+
+        # create the service:
+        db_api.add_service(ctxt, service)
+        LOG.debug(
+            "Added new service to DB: %s", service.id)
+
+        # add region associations:
+        if mapped_regions:
+            try:
+                db_api.update_service(
+                    ctxt, service.id, {
+                        "mapped_regions": mapped_regions})
+            except Exception as ex:
+                LOG.warn(
+                    "Error adding region mappings during new service "
+                    "registration (host: %s), cleaning up endpoint and "
+                    "all created mappings for regions: %s",
+                    service.host, mapped_regions)
+                db_api.delete_service(ctxt, service.id)
+                raise
+
+        return self.get_service(ctxt, service.id)
+
+    def check_service_registered(self, ctxt, host, binary, topic):
+        props = "host='%s', binary='%s', topic='%s'" % (host, binary, topic)
+        LOG.debug(
+            "Checking for existence of service with properties: %s", props)
+        service = db_api.find_service(ctxt, host, binary, topic=topic)
+        if service:
+            LOG.debug(
+                "Found service '%s' for properties %s", service.id, props)
+        else:
+            LOG.debug(
+                "Could not find any service with the specified "
+                "properties: %s", props)
+        return service
+
+    @service_synchronized
+    def refresh_service_status(self, ctxt, service_id):
+        LOG.debug("Updating registration for worker service '%s'", service_id)
+        service = db_api.get_service(ctxt, service_id)
+        worker_rpc = self._get_rpc_client_for_service(service)
+        status = worker_rpc.get_service_status(ctxt)
+        updated_values = {
+            "providers": status["providers"],
+            "specs": status["specs"],
+            "status": constants.SERVICE_STATUS_UP}
+        db_api.update_service(ctxt, service_id, updated_values)
+        LOG.debug("Successfully refreshed status of service '%s'", service_id)
+        return db_api.get_service(ctxt, service_id)
+
+    def get_services(self, ctxt):
+        return db_api.get_services(ctxt)
+
+    @service_synchronized
+    def get_service(self, ctxt, service_id):
+        service = db_api.get_service(ctxt, service_id)
+        if not service:
+            raise exception.NotFound(
+                "Service with ID '%s' not found." % service_id)
+        return service
+
+    @service_synchronized
+    def update_service(self, ctxt, service_id, updated_values):
+        LOG.info(
+            "Attempting to update service '%s' with payload: %s",
+            service_id, updated_values)
+        db_api.update_service(ctxt, service_id, updated_values)
+        LOG.info("Successfully updated service '%s'", service_id)
+        return db_api.get_service(ctxt, service_id)
+
+    @service_synchronized
+    def delete_service(self, ctxt, service_id):
+        db_api.delete_service(ctxt, service_id)

+ 54 - 0
coriolis/conductor/rpc/utils.py

@@ -0,0 +1,54 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import time
+
+from oslo_log import log as logging
+
+from coriolis import utils
+
+
+LOG = logging.getLogger(__name__)
+
+
+def check_create_registration_for_service(
+        conductor_rpc, request_context, host, binary, topic, enabled=False,
+        mapped_regions=None, providers=None, specs=None, retry_period=30):
+    """ Checks with the conductor whether or not a service has already been
+    registered for this host and topic and creates one if not.
+    If the service is already registered, directs the conductor to refresh the
+    service status.
+    """
+    props = "host='%s', binary='%s', topic='%s'" % (host, binary, topic)
+    while True:
+        try:
+            # check is service already exists:
+            LOG.info(
+                "Checking with conductor if service with following porperties "
+                "was already registered: %s", props)
+            worker_service = conductor_rpc.check_service_registered(
+                request_context, host, binary, topic)
+            if worker_service:
+                LOG.info(
+                    "A service with properties %s has already been registered "
+                    "under ID '%s'. Updating existing registration.",
+                    props, worker_service['id'])
+                worker_service = conductor_rpc.update_service(
+                    request_context, worker_service['id'], updated_values={
+                        "providers": providers,
+                        "specs": specs})
+            else:
+                LOG.debug(
+                    "Attempting to register new service with properties: %s",
+                    props)
+                worker_service = conductor_rpc.register_service(
+                    request_context, host, binary, topic, enabled,
+                    mapped_regions=mapped_regions, providers=providers,
+                    specs=specs)
+            return worker_service
+        except Exception as ex:
+            LOG.warn(
+                "Failed to register service with specs %s. Retrying again in "
+                "%d seconds. Error was: %s", props, retry_period,
+                utils.get_exception_details())
+            time.sleep(retry_period)

+ 28 - 2
coriolis/constants.py

@@ -1,6 +1,8 @@
 # Copyright 2016 Cloudbase Solutions Srl
 # Copyright 2016 Cloudbase Solutions Srl
 # All Rights Reserved.
 # All Rights Reserved.
 
 
+DEFAULT_CORIOLIS_REGION_NAME = "Default Region"
+
 EXECUTION_STATUS_RUNNING = "RUNNING"
 EXECUTION_STATUS_RUNNING = "RUNNING"
 EXECUTION_STATUS_COMPLETED = "COMPLETED"
 EXECUTION_STATUS_COMPLETED = "COMPLETED"
 EXECUTION_STATUS_ERROR = "ERROR"
 EXECUTION_STATUS_ERROR = "ERROR"
@@ -24,6 +26,7 @@ FINALIZED_EXECUTION_STATUSES = [
 
 
 TASK_STATUS_SCHEDULED = "SCHEDULED"
 TASK_STATUS_SCHEDULED = "SCHEDULED"
 TASK_STATUS_PENDING = "PENDING"
 TASK_STATUS_PENDING = "PENDING"
+TASK_STATUS_STARTING = "STARTING"
 TASK_STATUS_UNSCHEDULED = "UNSCHEDULED"
 TASK_STATUS_UNSCHEDULED = "UNSCHEDULED"
 TASK_STATUS_RUNNING = "RUNNING"
 TASK_STATUS_RUNNING = "RUNNING"
 TASK_STATUS_COMPLETED = "COMPLETED"
 TASK_STATUS_COMPLETED = "COMPLETED"
@@ -36,9 +39,11 @@ TASK_STATUS_CANCELLING_AFTER_COMPLETION = "CANCELLING_AFTER_COMPLETION"
 TASK_STATUS_CANCELED_FOR_DEBUGGING = "CANCELED_FOR_DEBUGGING"
 TASK_STATUS_CANCELED_FOR_DEBUGGING = "CANCELED_FOR_DEBUGGING"
 TASK_STATUS_CANCELED_FROM_DEADLOCK = "STRANDED_AFTER_DEADLOCK"
 TASK_STATUS_CANCELED_FROM_DEADLOCK = "STRANDED_AFTER_DEADLOCK"
 TASK_STATUS_ON_ERROR_ONLY = "EXECUTE_ON_ERROR_ONLY"
 TASK_STATUS_ON_ERROR_ONLY = "EXECUTE_ON_ERROR_ONLY"
+TASK_STATUS_FAILED_TO_SCHEDULE = "FAILED_TO_SCHEDULE"
 
 
 ACTIVE_TASK_STATUSES = [
 ACTIVE_TASK_STATUSES = [
     TASK_STATUS_PENDING,
     TASK_STATUS_PENDING,
+    TASK_STATUS_STARTING,
     TASK_STATUS_RUNNING,
     TASK_STATUS_RUNNING,
     TASK_STATUS_CANCELLING,
     TASK_STATUS_CANCELLING,
     TASK_STATUS_CANCELLING_AFTER_COMPLETION
     TASK_STATUS_CANCELLING_AFTER_COMPLETION
@@ -50,7 +55,8 @@ CANCELED_TASK_STATUSES = [
     TASK_STATUS_FORCE_CANCELED,
     TASK_STATUS_FORCE_CANCELED,
     TASK_STATUS_CANCELED_AFTER_COMPLETION,
     TASK_STATUS_CANCELED_AFTER_COMPLETION,
     TASK_STATUS_CANCELED_FOR_DEBUGGING,
     TASK_STATUS_CANCELED_FOR_DEBUGGING,
-    TASK_STATUS_CANCELED_FROM_DEADLOCK
+    TASK_STATUS_CANCELED_FROM_DEADLOCK,
+    TASK_STATUS_FAILED_TO_SCHEDULE
 ]
 ]
 
 
 FINALIZED_TASK_STATUSES = [
 FINALIZED_TASK_STATUSES = [
@@ -61,7 +67,8 @@ FINALIZED_TASK_STATUSES = [
     TASK_STATUS_FORCE_CANCELED,
     TASK_STATUS_FORCE_CANCELED,
     TASK_STATUS_CANCELED_FOR_DEBUGGING,
     TASK_STATUS_CANCELED_FOR_DEBUGGING,
     TASK_STATUS_CANCELED_FROM_DEADLOCK,
     TASK_STATUS_CANCELED_FROM_DEADLOCK,
-    TASK_STATUS_CANCELED_AFTER_COMPLETION
+    TASK_STATUS_CANCELED_AFTER_COMPLETION,
+    TASK_STATUS_FAILED_TO_SCHEDULE
 ]
 ]
 
 
 TASK_TYPE_DEPLOY_MIGRATION_SOURCE_RESOURCES = (
 TASK_TYPE_DEPLOY_MIGRATION_SOURCE_RESOURCES = (
@@ -122,6 +129,13 @@ TASK_TYPE_VALIDATE_REPLICA_DEPLOYMENT_INPUTS = (
 TASK_TYPE_UPDATE_SOURCE_REPLICA = "UPDATE_SOURCE_REPLICA"
 TASK_TYPE_UPDATE_SOURCE_REPLICA = "UPDATE_SOURCE_REPLICA"
 TASK_TYPE_UPDATE_DESTINATION_REPLICA = "UPDATE_DESTINATION_REPLICA"
 TASK_TYPE_UPDATE_DESTINATION_REPLICA = "UPDATE_DESTINATION_REPLICA"
 
 
+TASK_PLATFORM_SOURCE = "source"
+TASK_PLATFORM_DESTINATION = "destination"
+TASK_PLATFORM_BILATERAL = "bilateral"
+
+PROVIDER_PLATFORM_SOURCE = "source"
+PROVIDER_PLATFORM_DESTINATION = "destination"
+
 PROVIDER_TYPE_IMPORT = 1
 PROVIDER_TYPE_IMPORT = 1
 PROVIDER_TYPE_EXPORT = 2
 PROVIDER_TYPE_EXPORT = 2
 PROVIDER_TYPE_REPLICA_IMPORT = 4
 PROVIDER_TYPE_REPLICA_IMPORT = 4
@@ -197,6 +211,8 @@ ENDPOINT_LOCK_NAME_FORMAT = "endpoint-%s"
 MIGRATION_LOCK_NAME_FORMAT = "migration-%s"
 MIGRATION_LOCK_NAME_FORMAT = "migration-%s"
 REPLICA_LOCK_NAME_FORMAT = "replica-%s"
 REPLICA_LOCK_NAME_FORMAT = "replica-%s"
 SCHEDULE_LOCK_NAME_FORMAT = "schedule-%s"
 SCHEDULE_LOCK_NAME_FORMAT = "schedule-%s"
+REGION_LOCK_NAME_FORMAT = "region-%s"
+SERVICE_LOCK_NAME_FORMAT = "service-%s"
 
 
 EXECUTION_TYPE_TO_ACTION_LOCK_NAME_FORMAT_MAP = {
 EXECUTION_TYPE_TO_ACTION_LOCK_NAME_FORMAT_MAP = {
     EXECUTION_TYPE_MIGRATION: MIGRATION_LOCK_NAME_FORMAT,
     EXECUTION_TYPE_MIGRATION: MIGRATION_LOCK_NAME_FORMAT,
@@ -205,3 +221,13 @@ EXECUTION_TYPE_TO_ACTION_LOCK_NAME_FORMAT_MAP = {
     EXECUTION_TYPE_REPLICA_UPDATE: REPLICA_LOCK_NAME_FORMAT,
     EXECUTION_TYPE_REPLICA_UPDATE: REPLICA_LOCK_NAME_FORMAT,
     EXECUTION_TYPE_REPLICA_DISKS_DELETE: REPLICA_LOCK_NAME_FORMAT
     EXECUTION_TYPE_REPLICA_DISKS_DELETE: REPLICA_LOCK_NAME_FORMAT
 }
 }
+
+SERVICE_STATUS_UP = "UP"
+SERVICE_STATUS_DOWN = "DOWN"
+SERVICE_STATUS_UNKNOWN = "UNKNOWN"
+
+SERVICE_MESSAGING_TOPIC_FORMAT = "%(main_topic)s.%(host)s"
+CONDUCTOR_MAIN_MESSAGING_TOPIC = "coriolis_conductor"
+WORKER_MAIN_MESSAGING_TOPIC = "coriolis_worker"
+SCHEDULER_MAIN_MESSAGING_TOPIC = "coriolis_scheduler"
+REPLICA_CRON_MAIN_MESSAGING_TOPIC = "coriolis_replica_cron_worker"

+ 476 - 23
coriolis/db/api.py

@@ -14,6 +14,7 @@ from sqlalchemy.sql import null
 
 
 from coriolis.db.sqlalchemy import models
 from coriolis.db.sqlalchemy import models
 from coriolis import exception
 from coriolis import exception
+from coriolis import utils
 
 
 CONF = cfg.CONF
 CONF = cfg.CONF
 db_options.set_defaults(CONF)
 db_options.set_defaults(CONF)
@@ -62,6 +63,38 @@ def _model_query(context, *args):
     return session.query(*args)
     return session.query(*args)
 
 
 
 
+def _update_sqlalchemy_object_fields(obj, updateable_fields, values_to_update):
+    """ Updates the given 'values_to_update' on the provided sqlalchemy object
+    as long as they are included as 'updateable_fields'.
+    :param obj: object: sqlalchemy object
+    :param updateable_fields: list(str): list of fields which are updateable
+    :param values_to_update: dict: dict with the key/vals to update
+    """
+    if not isinstance(values_to_update, dict):
+        raise exception.InvalidInput(
+            "Properties to update for DB object of type '%s' must be a dict, "
+            "got the following (type %s): %s" % (
+                type(obj), type(values_to_update), values_to_update))
+
+    non_updateable_fields = set(
+        values_to_update.keys()).difference(
+            set(updateable_fields))
+    if non_updateable_fields:
+        raise exception.Conflict(
+            "Fields %s for '%s' database cannot be updated. "
+            "Only updateable fields are: %s" % (
+                non_updateable_fields, type(obj), updateable_fields))
+
+    for field_name, field_val in values_to_update.items():
+        if not hasattr(obj, field_name):
+            raise exception.InvalidInput(
+                "No region field named '%s' to update." % field_name)
+        setattr(obj, field_name, field_val)
+    LOG.debug(
+        "Successfully updated the following fields on DB object "
+        "of type '%s': %s" % (type(obj), values_to_update.keys()))
+
+
 def _get_replica_schedules_filter(context, replica_id=None,
 def _get_replica_schedules_filter(context, replica_id=None,
                                   schedule_id=None, expired=True):
                                   schedule_id=None, expired=True):
     now = timeutils.utcnow()
     now = timeutils.utcnow()
@@ -91,7 +124,9 @@ def _soft_delete_aware_query(context, *args, **kwargs):
     :param show_deleted: if True, overrides context's show_deleted field.
     :param show_deleted: if True, overrides context's show_deleted field.
     """
     """
     query = _model_query(context, *args)
     query = _model_query(context, *args)
-    show_deleted = kwargs.get('show_deleted') or context.show_deleted
+    show_deleted = kwargs.get('show_deleted')
+    if context and context.show_deleted:
+        show_deleted = True
 
 
     if not show_deleted:
     if not show_deleted:
         query = query.filter_by(deleted_at=None)
         query = query.filter_by(deleted_at=None)
@@ -100,7 +135,8 @@ def _soft_delete_aware_query(context, *args, **kwargs):
 
 
 @enginefacade.reader
 @enginefacade.reader
 def get_endpoints(context):
 def get_endpoints(context):
-    q = _soft_delete_aware_query(context, models.Endpoint)
+    q = _soft_delete_aware_query(context, models.Endpoint).options(
+        orm.joinedload('mapped_regions'))
     if is_user_context(context):
     if is_user_context(context):
         q = q.filter(
         q = q.filter(
             models.Endpoint.project_id == context.tenant)
             models.Endpoint.project_id == context.tenant)
@@ -109,7 +145,8 @@ def get_endpoints(context):
 
 
 @enginefacade.reader
 @enginefacade.reader
 def get_endpoint(context, endpoint_id):
 def get_endpoint(context, endpoint_id):
-    q = _soft_delete_aware_query(context, models.Endpoint)
+    q = _soft_delete_aware_query(context, models.Endpoint).options(
+        orm.joinedload('mapped_regions'))
     if is_user_context(context):
     if is_user_context(context):
         q = q.filter(
         q = q.filter(
             models.Endpoint.project_id == context.tenant)
             models.Endpoint.project_id == context.tenant)
@@ -121,28 +158,118 @@ def get_endpoint(context, endpoint_id):
 def add_endpoint(context, endpoint):
 def add_endpoint(context, endpoint):
     endpoint.user_id = context.user
     endpoint.user_id = context.user
     endpoint.project_id = context.tenant
     endpoint.project_id = context.tenant
-    context.session.add(endpoint)
+    _session(context).add(endpoint)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
 def update_endpoint(context, endpoint_id, updated_values):
 def update_endpoint(context, endpoint_id, updated_values):
     endpoint = get_endpoint(context, endpoint_id)
     endpoint = get_endpoint(context, endpoint_id)
     if not endpoint:
     if not endpoint:
-        raise exception.NotFound("Endpoint not found")
-    for n in ["name", "description", "connection_info"]:
-        if n in updated_values:
-            setattr(endpoint, n, updated_values[n])
+        raise exception.NotFound("Endpoint with ID '%s' found" % endpoint_id)
+
+
+    if not isinstance(updated_values, dict):
+        raise exception.InvalidInput(
+            "Update payload for endpoints must be a dict. Got the following "
+            "(type: %s): %s" % (type(updated_values), updated_values))
+
+    def _try_unmap_regions(region_ids):
+         for region_to_unmap in region_ids:
+            try:
+                LOG.debug(
+                    "Attempting to unmap region '%s' from endpoint '%s'",
+                    region_to_unmap, endpoint_id)
+                delete_endpoint_region_mapping(
+                    context, endpoint_id, region_to_unmap)
+            except Exception as ex:
+                LOG.warn(
+                    "Exception occurred while attempting to unmap region '%s' "
+                    "from endpoint '%s'. Ignoring. Error was: %s",
+                    region_to_unmap, endpoint_id,
+                    utils.get_exception_details())
+
+    newly_mapped_regions = []
+    regions_to_unmap = []
+    # NOTE: `.pop()` required for  `_update_sqlalchemy_object_fields` call:
+    desired_region_mappings = updated_values.pop('mapped_regions', None)
+    if desired_region_mappings is not None:
+        # ensure all requested regions exist:
+        for region_id in desired_region_mappings:
+            region = get_region(context, region_id)
+            if not region:
+                raise exception.NotFound(
+                    "Could not find region with ID '%s' for associating "
+                    "with endpoint '%s' during update process." % (
+                        region_id, endpoint_id))
+
+        # get all existing mappings:
+        existing_region_mappings = [
+            mapping.region_id
+            for mapping in get_region_mappings_for_endpoint(
+                context, endpoint_id)]
+
+        # check and add new mappings:
+        to_map = set(
+            desired_region_mappings).difference(set(existing_region_mappings))
+        regions_to_unmap = set(
+            existing_region_mappings).difference(set(desired_region_mappings))
+
+        LOG.debug(
+            "Remapping regions for endpoint '%s' from %s to %s",
+            endpoint_id, existing_region_mappings, desired_region_mappings)
+
+        region_id = None
+        try:
+            for region_id in to_map:
+                mapping = models.EndpointRegionMapping()
+                mapping.region_id = region_id
+                mapping.endpoint_id = endpoint_id
+                add_endpoint_region_mapping(context, mapping)
+                newly_mapped_regions.append(region_id)
+        except Exception as ex:
+            LOG.warn(
+                "Exception occurred while adding region mapping for '%s' to "
+                "endpoint '%s'. Cleaning up created mappings (%s). Error was: "
+                "%s", region_id, endpoint_id, newly_mapped_regions,
+                utils.get_exception_details())
+            _try_unmap_regions(newly_mapped_regions)
+            raise
+
+
+    updateable_fields = ["name", "description", "connection_info"]
+    try:
+        _update_sqlalchemy_object_fields(
+            endpoint, updateable_fields, updated_values)
+    except Exception as ex:
+        LOG.warn(
+            "Exception occurred while updating fields of endpoint '%s'. "
+            "Cleaning ""up created mappings (%s). Error was: %s",
+            endpoint_id, newly_mapped_regions, utils.get_exception_details())
+        _try_unmap_regions(newly_mapped_regions)
+        raise
+
+    # remove all of the old region mappings:
+    LOG.debug(
+        "Unmapping the following regions during update of endpoint '%s': %s",
+        endpoint_id, regions_to_unmap)
+    _try_unmap_regions(regions_to_unmap)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
 def delete_endpoint(context, endpoint_id):
 def delete_endpoint(context, endpoint_id):
+    endpoint = get_endpoint(context, endpoint_id)
     args = {"id": endpoint_id}
     args = {"id": endpoint_id}
     if is_user_context(context):
     if is_user_context(context):
         args["project_id"] = context.tenant
         args["project_id"] = context.tenant
     count = _soft_delete_aware_query(context, models.Endpoint).filter_by(
     count = _soft_delete_aware_query(context, models.Endpoint).filter_by(
         **args).soft_delete()
         **args).soft_delete()
     if count == 0:
     if count == 0:
-        raise exception.NotFound("0 entries were soft deleted")
+        raise exception.NotFound("0 Endpoint entries were soft deleted")
+    # NOTE(aznashwan): many-to-many tables with soft deletion on either end of
+    # the association are not handled properly so we must manually delete each
+    # association ourselves:
+    for reg in endpoint.mapped_regions:
+        delete_endpoint_region_mapping(context, endpoint_id, reg.id)
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
@@ -181,7 +308,7 @@ def add_replica_tasks_execution(context, execution):
             action_id=execution.action.id).first()[0] or 0
             action_id=execution.action.id).first()[0] or 0
     execution.number = max_number + 1
     execution.number = max_number + 1
 
 
-    context.session.add(execution)
+    _session(context).add(execution)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -268,7 +395,7 @@ def add_replica_schedule(context, schedule, post_create_callable=None):
 
 
     if schedule.replica.project_id != context.tenant:
     if schedule.replica.project_id != context.tenant:
         raise exception.NotAuthorized()
         raise exception.NotAuthorized()
-    context.session.add(schedule)
+    _session(context).add(schedule)
     if post_create_callable:
     if post_create_callable:
         post_create_callable(context, schedule)
         post_create_callable(context, schedule)
 
 
@@ -280,7 +407,8 @@ def _get_replica_with_tasks_executions_options(q):
 @enginefacade.reader
 @enginefacade.reader
 def get_replicas(context,
 def get_replicas(context,
                  include_tasks_executions=False,
                  include_tasks_executions=False,
-                 include_info=False):
+                 include_info=False,
+                 to_dict=True):
     q = _soft_delete_aware_query(context, models.Replica)
     q = _soft_delete_aware_query(context, models.Replica)
     if include_tasks_executions:
     if include_tasks_executions:
         q = _get_replica_with_tasks_executions_options(q)
         q = _get_replica_with_tasks_executions_options(q)
@@ -291,7 +419,9 @@ def get_replicas(context,
         q = q.filter(
         q = q.filter(
             models.Replica.project_id == context.tenant)
             models.Replica.project_id == context.tenant)
     db_result = q.all()
     db_result = q.all()
-    return [i.to_dict(include_info=include_info) for i in db_result]
+    if to_dict:
+        return [i.to_dict(include_info=include_info) for i in db_result]
+    return db_result
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
@@ -322,7 +452,7 @@ def get_endpoint_replicas_count(context, endpoint_id):
 def add_replica(context, replica):
 def add_replica(context, replica):
     replica.user_id = context.user
     replica.user_id = context.user
     replica.project_id = context.tenant
     replica.project_id = context.tenant
-    context.session.add(replica)
+    _session(context).add(replica)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -358,7 +488,7 @@ def get_replica_migrations(context, replica_id):
 
 
 @enginefacade.reader
 @enginefacade.reader
 def get_migrations(context, include_tasks=False,
 def get_migrations(context, include_tasks=False,
-                   include_info=False):
+                   include_info=False, to_dict=True):
     q = _soft_delete_aware_query(context, models.Migration)
     q = _soft_delete_aware_query(context, models.Migration)
     if include_tasks:
     if include_tasks:
         q = _get_migration_task_query_options(q)
         q = _get_migration_task_query_options(q)
@@ -371,8 +501,9 @@ def get_migrations(context, include_tasks=False,
     if is_user_context(context):
     if is_user_context(context):
         args["project_id"] = context.tenant
         args["project_id"] = context.tenant
     result = q.filter_by(**args).all()
     result = q.filter_by(**args).all()
-    to_dict = [i.to_dict(include_info=include_info) for i in result]
-    return to_dict
+    if to_dict:
+        return [i.to_dict(include_info=include_info) for i in result]
+    return result
 
 
 
 
 def _get_tasks_with_details_options(query):
 def _get_tasks_with_details_options(query):
@@ -410,7 +541,7 @@ def get_migration(context, migration_id):
 def add_migration(context, migration):
 def add_migration(context, migration):
     migration.user_id = context.user
     migration.user_id = context.user
     migration.project_id = context.tenant
     migration.project_id = context.tenant
-    context.session.add(migration)
+    _session(context).add(migration)
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
@@ -541,10 +672,12 @@ def set_task_status(context, task_id, status, exception_details=None):
 
 
 
 
 @enginefacade.writer
 @enginefacade.writer
-def set_task_host(context, task_id, host, process_id):
+def set_task_host_properties(context, task_id, host=None, process_id=None):
     task = _get_task(context, task_id)
     task = _get_task(context, task_id)
-    task.host = host
-    task.process_id = process_id
+    if host:
+        task.host = host
+    if process_id:
+        task.process_id = process_id
 
 
 
 
 @enginefacade.reader
 @enginefacade.reader
@@ -559,7 +692,7 @@ def add_task_event(context, task_id, level, message):
     task_event.task_id = task_id
     task_event.task_id = task_id
     task_event.level = level
     task_event.level = level
     task_event.message = message
     task_event.message = message
-    context.session.add(task_event)
+    _session(context).add(task_event)
 
 
 
 
 def _get_progress_update(context, task_id, current_step):
 def _get_progress_update(context, task_id, current_step):
@@ -575,7 +708,7 @@ def add_task_progress_update(context, task_id, current_step, total_steps,
     task_progress_update = _get_progress_update(context, task_id, current_step)
     task_progress_update = _get_progress_update(context, task_id, current_step)
     if not task_progress_update:
     if not task_progress_update:
         task_progress_update = models.TaskProgressUpdate()
         task_progress_update = models.TaskProgressUpdate()
-        context.session.add(task_progress_update)
+        _session(context).add(task_progress_update)
 
 
     task_progress_update.task_id = task_id
     task_progress_update.task_id = task_id
     task_progress_update.current_step = current_step
     task_progress_update.current_step = current_step
@@ -617,3 +750,323 @@ def update_replica(context, replica_id, updated_values):
     # the oslo_db library uses this method for both the `created_at` and
     # the oslo_db library uses this method for both the `created_at` and
     # `updated_at` fields
     # `updated_at` fields
     setattr(replica, 'updated_at', timeutils.utcnow())
     setattr(replica, 'updated_at', timeutils.utcnow())
+
+
+@enginefacade.writer
+def add_region(context, region):
+    _session(context).add(region)
+
+
+@enginefacade.reader
+def get_regions(context):
+    q = _soft_delete_aware_query(context, models.Region)
+    q = q.options(orm.joinedload('mapped_endpoints'))
+    q = q.options(orm.joinedload('mapped_services'))
+    return q.all()
+
+
+@enginefacade.reader
+def get_region(context, region_id):
+    q = _soft_delete_aware_query(context, models.Region)
+    q = q.options(orm.joinedload('mapped_endpoints'))
+    q = q.options(orm.joinedload('mapped_services'))
+    return q.filter(
+        models.Region.id == region_id).first()
+
+
+@enginefacade.writer
+def update_region(context, region_id, updated_values):
+    if not region_id:
+        raise exception.InvalidInput(
+            "No region ID specified for updating.")
+    region = get_region(context, region_id)
+    if not region:
+        raise exception.NotFound(
+            "Region with ID '%s' does not exist." % region_id)
+
+    updateable_fields = ["name", "description", "enabled"]
+    _update_sqlalchemy_object_fields(
+        region, updateable_fields, updated_values)
+
+
+@enginefacade.writer
+def delete_region(context, region_id):
+    region = get_region(context, region_id)
+    count = _soft_delete_aware_query(context, models.Region).filter_by(
+        id=region_id).soft_delete()
+    if count == 0:
+        raise exception.NotFound("0 region entries were soft deleted")
+    # NOTE(aznashwan): many-to-many tables with soft deletion on either end of
+    # the association are not handled properly so we must manually delete each
+    # association ourselves:
+    for endp in region.mapped_endpoints:
+        delete_endpoint_region_mapping(context, endp.id, region_id)
+    for svc in region.mapped_services:
+        delete_service_region_mapping(context, svc.id, region_id)
+
+@enginefacade.writer
+def add_endpoint_region_mapping(context, endpoint_region_mapping):
+    region_id = endpoint_region_mapping.region_id
+    endpoint_id = endpoint_region_mapping.endpoint_id
+
+    if None in [region_id, endpoint_id]:
+        raise exception.InvalidInput(
+            "Provided endpoint region mapping params for the region ID "
+            "('%s') and the endpoint ID ('%s') must both be non-null." % (
+                region_id, endpoint_id))
+
+    _session(context).add(endpoint_region_mapping)
+
+
+@enginefacade.reader
+def get_endpoint_region_mapping(context, endpoint_id, region_id):
+    q = _soft_delete_aware_query(context, models.EndpointRegionMapping)
+    q = q.filter(
+        models.EndpointRegionMapping.region == region_id)
+    q = q.filter(
+        models.EndpointRegionMapping.endpoint_id == endpoint_id)
+    return q.all()
+
+
+@enginefacade.writer
+def delete_endpoint_region_mapping(context, endpoint_id, region_id):
+    args = {"endpoint_id": endpoint_id, "region_id": region_id}
+    # TODO(aznashwan): many-to-many realtionships have no sane way of
+    # supporting soft deletion from the sqlalchemy layer wihout
+    # writing join condictions, so we hard-`delete()` instead of
+    # `soft_delete()` util we find a better option:
+    count = _soft_delete_aware_query(
+        context, models.EndpointRegionMapping).filter_by(
+            **args).delete()
+    if count == 0:
+        raise exception.NotFound(
+            "There is no mapping between endpoint '%s' and region '%s'." % (
+                endpoint_id, region_id))
+    LOG.debug(
+        "Deleted mapping between endpoint '%s' and region '%s' from DB",
+        endpoint_id, region_id)
+
+
+@enginefacade.reader
+def get_region_mappings_for_endpoint(
+        context, endpoint_id, enabled_regions_only=False):
+    q = _soft_delete_aware_query(context, models.EndpointRegionMapping)
+    q = q.join(models.Region)
+    q = q.filter(
+        models.EndpointRegionMapping.endpoint_id == endpoint_id)
+    if enabled_regions_only:
+        q = q.filter(
+            models.Region.enabled == True)
+    return q.all()
+
+
+@enginefacade.reader
+def get_mapped_endpoints_for_region(context, region_id):
+    q = _soft_delete_aware_query(context, models.Endpoint)
+    q = q.join(models.EndpointRegionMapping)
+    q = q.filter(
+        models.EndpointRegionMapping.endpoint_id == region_id)
+    return q.all()
+
+
+@enginefacade.writer
+def add_service(context, service):
+    _session(context).add(service)
+
+
+@enginefacade.reader
+def get_services(context):
+    q = _soft_delete_aware_query(context, models.Service).options(
+        orm.joinedload('mapped_regions'))
+    return q.all()
+
+
+@enginefacade.reader
+def get_service(context, service_id):
+    q = _soft_delete_aware_query(context, models.Service).options(
+        orm.joinedload('mapped_regions'))
+    return q.filter(
+        models.Service.id == service_id).first()
+
+
+@enginefacade.reader
+def find_service(context, host, binary, topic=None):
+    args = {"host": host, "binary": binary}
+    if topic:
+        args["topic"] = topic
+    q = _soft_delete_aware_query(context, models.Service).options(
+         orm.joinedload('mapped_regions')).filter_by(**args)
+    return q.first()
+
+
+@enginefacade.writer
+def update_service(context, service_id, updated_values):
+    if not service_id:
+        raise exception.InvalidInput(
+            "No service ID specified for updating.")
+    service = get_service(context, service_id)
+    if not service:
+        raise exception.NotFound(
+            "Service with ID '%s' does not exist." % service_id)
+
+    if not isinstance(updated_values, dict):
+        raise exception.InvalidInput(
+            "Update payload for services must be a dict. Got the following "
+            "(type: %s): %s" % (type(updated_values), updated_values))
+
+    def _try_unmap_regions(region_ids):
+         for region_to_unmap in region_ids:
+            try:
+                LOG.debug(
+                    "Attempting to unmap region '%s' from service '%s'",
+                    region_to_unmap, service_id)
+                delete_service_region_mapping(
+                    context, service_id, region_to_unmap)
+            except Exception as ex:
+                LOG.warn(
+                    "Exception occurred while attempting to unmap region '%s' "
+                    "from service '%s'. Ignoring. Error was: %s",
+                    region_to_unmap, service_id,
+                    utils.get_exception_details())
+
+    newly_mapped_regions = []
+    regions_to_unmap = []
+    # NOTE: `.pop()` required for  `_update_sqlalchemy_object_fields` call:
+    desired_region_mappings = updated_values.pop('mapped_regions', None)
+    if desired_region_mappings is not None:
+        # ensure all requested regions exist:
+        for region_id in desired_region_mappings:
+            region = get_region(context, region_id)
+            if not region:
+                raise exception.NotFound(
+                    "Could not find region with ID '%s' for associating "
+                    "with serce '%s' during update process." % (
+                        region_id, service_id))
+
+        # get all existing mappings:
+        existing_region_mappings = [
+            mapping.region_id
+            for mapping in get_region_mappings_for_service(
+                context, service_id)]
+
+        # check and add new mappings:
+        to_map = set(
+            desired_region_mappings).difference(set(existing_region_mappings))
+        regions_to_unmap = set(
+            existing_region_mappings).difference(set(desired_region_mappings))
+
+        LOG.debug(
+            "Remapping regions for service '%s' from %s to %s",
+            service_id, existing_region_mappings, desired_region_mappings)
+
+        region_id = None
+        try:
+            for region_id in to_map:
+                mapping = models.ServiceRegionMapping()
+                mapping.region_id = region_id
+                mapping.service_id = service_id
+                add_service_region_mapping(context, mapping)
+                newly_mapped_regions.append(region_id)
+        except Exception as ex:
+            LOG.warn(
+                "Exception occurred while adding region mapping for '%s' to "
+                "service '%s'. Cleaning up created mappings (%s). Error was: "
+                "%s", region_id, service_id, newly_mapped_regions,
+                utils.get_exception_details())
+            _try_unmap_regions(newly_mapped_regions)
+            raise
+
+
+    updateable_fields = ["enabled", "status", "providers", "specs"]
+    try:
+        _update_sqlalchemy_object_fields(
+            service, updateable_fields, updated_values)
+    except Exception as ex:
+        LOG.warn(
+            "Exception occurred while updating fields of service '%s'. "
+            "Cleaning ""up created mappings (%s). Error was: %s",
+            service_id, newly_mapped_regions, utils.get_exception_details())
+        _try_unmap_regions(newly_mapped_regions)
+        raise
+
+    # remove all of the old region mappings:
+    LOG.debug(
+        "Unmapping the following regions during update of service '%s': %s",
+        service_id, regions_to_unmap)
+    _try_unmap_regions(regions_to_unmap)
+
+
+@enginefacade.writer
+def delete_service(context, service_id):
+    service = get_service(context, service_id)
+    count = _soft_delete_aware_query(context, models.Service).filter_by(
+        id=service_id).soft_delete()
+    if count == 0:
+        raise exception.NotFound("0 service entries were soft deleted")
+    # NOTE(aznashwan): many-to-many tables with soft deletion on either end of
+    # the association are not handled properly so we must manually delete each
+    # association ourselves:
+    for reg in service.mapped_regions:
+        delete_service_region_mapping(context, service_id, reg.id)
+
+
+@enginefacade.writer
+def add_service_region_mapping(context, service_region_mapping):
+    region_id = service_region_mapping.region_id
+    service_id = service_region_mapping.service_id
+
+    if None in [region_id, service_id]:
+        raise exception.InvalidInput(
+            "Provided service region mapping params for the region ID "
+            "('%s') and the service ID ('%s') must both be non-null." % (
+                region_id, service_id))
+
+    _session(context).add(service_region_mapping)
+
+
+@enginefacade.reader
+def get_service_region_mapping(context, service_id, region_id):
+    q = _soft_delete_aware_query(context, models.ServiceRegionMapping)
+    q = q.filter(
+        models.ServiceRegionMapping.region == region_id)
+    q = q.filter(
+        models.ServiceRegionMapping.service_id == service_id)
+    return q.all()
+
+
+@enginefacade.writer
+def delete_service_region_mapping(context, service_id, region_id):
+    args = {"service_id": service_id, "region_id": region_id}
+    # TODO(aznashwan): many-to-many realtionships have no sane way of
+    # supporting soft deletion from the sqlalchemy layer wihout
+    # writing join condictions, so we hard-`delete()` instead of
+    # `soft_delete()` util we find a better option:
+    count = _soft_delete_aware_query(
+        context, models.ServiceRegionMapping).filter_by(
+            **args).delete()
+    if count == 0:
+        raise exception.NotFound(
+            "There is no mapping between service '%s' and region '%s'." % (
+                service_id, region_id))
+
+
+@enginefacade.reader
+def get_region_mappings_for_service(
+        context, service_id, enabled_regions_only=False):
+    q = _soft_delete_aware_query(context, models.ServiceRegionMapping)
+    q = q.join(models.Region)
+    q = q.filter(
+        models.ServiceRegionMapping.service_id == service_id)
+    if enabled_regions_only:
+        q = q.filter(
+            models.Region.enabled == True)
+    return q.all()
+
+
+@enginefacade.reader
+def get_mapped_services_for_region(context, region_id):
+    q = _soft_delete_aware_query(context, models.Service)
+    q = q.join(models.ServiceRegionMapping)
+    q = q.filter(
+        models.ServiceRegionMapping.service_id == region_id)
+    return q.all()

+ 124 - 0
coriolis/db/sqlalchemy/migrate_repo/versions/014_adds_worker_service_regions.py

@@ -0,0 +1,124 @@
+# Copyright 2016 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import uuid
+
+import sqlalchemy
+
+
+def upgrade(migrate_engine):
+    meta = sqlalchemy.MetaData()
+    meta.bind = migrate_engine
+
+    sqlalchemy.Table(
+        'endpoint', meta, autoload=True)
+
+    tables = []
+
+    # declare region table:
+    tables.append(
+        sqlalchemy.Table(
+            'region',
+            meta,
+            sqlalchemy.Column('id', sqlalchemy.String(36), primary_key=True,
+                              default=lambda: str(uuid.uuid4())),
+            sqlalchemy.Column('name', sqlalchemy.String(255), nullable=False),
+            sqlalchemy.Column(
+                'description', sqlalchemy.String(1024), nullable=True),
+            sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted', sqlalchemy.String(36)),
+            sqlalchemy.Column(
+                'enabled', sqlalchemy.Boolean, nullable=True,
+                default=lambda: False)))
+
+    # declare endpoint-region-mapping table:
+    tables.append(
+        sqlalchemy.Table(
+            'endpoint_region_mapping',
+            meta,
+            sqlalchemy.Column(
+                'id',
+                sqlalchemy.String(36),
+                primary_key=True,
+                default=lambda: str(uuid.uuid4())),
+            sqlalchemy.Column(
+                'endpoint_id',
+                sqlalchemy.String(36),
+                sqlalchemy.ForeignKey('endpoint.id'),
+                nullable=False),
+            sqlalchemy.Column(
+                'region_id',
+                sqlalchemy.String(36),
+                sqlalchemy.ForeignKey('region.id'),
+                nullable=False),
+            sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted', sqlalchemy.String(36))))
+
+    # declare service table:
+    tables.append(
+        sqlalchemy.Table(
+            'service',
+            meta,
+            sqlalchemy.Column(
+                'id',
+                sqlalchemy.String(36),
+                primary_key=True,
+                default=lambda: str(uuid.uuid4())),
+            sqlalchemy.Column(
+                'enabled', sqlalchemy.Boolean, nullable=True,
+                default=lambda: False),
+            sqlalchemy.Column(
+                'host', sqlalchemy.String(255), nullable=False),
+            sqlalchemy.Column(
+                'binary', sqlalchemy.String(255), nullable=False),
+            sqlalchemy.Column(
+                'topic', sqlalchemy.String(255), nullable=False),
+            sqlalchemy.Column(
+                'status', sqlalchemy.String(255), nullable=False,
+                default=lambda: "UNKNOWN"),
+            sqlalchemy.Column(
+                'providers', sqlalchemy.Text(), nullable=False),
+            sqlalchemy.Column(
+                'specs', sqlalchemy.Text(), nullable=False),
+            sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted', sqlalchemy.String(36))))
+
+    # declare service-region mappings table:
+    tables.append(
+        sqlalchemy.Table(
+            'service_region_mapping',
+            meta,
+            sqlalchemy.Column(
+                'id',
+                sqlalchemy.String(36),
+                primary_key=True,
+                default=lambda: str(uuid.uuid4())),
+            sqlalchemy.Column(
+                'service_id',
+                sqlalchemy.String(36),
+                sqlalchemy.ForeignKey('service.id'),
+                nullable=False),
+            sqlalchemy.Column(
+                'region_id',
+                sqlalchemy.String(36),
+                sqlalchemy.ForeignKey('region.id'),
+                nullable=False),
+            sqlalchemy.Column('created_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('updated_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted_at', sqlalchemy.DateTime),
+            sqlalchemy.Column('deleted', sqlalchemy.String(36))))
+
+    for index, table in enumerate(tables):
+        try:
+            table.create()
+        except Exception:
+            # If an error occurs, drop all tables created so far to return
+            # to the previously existing state.
+            meta.drop_all(tables=tables[:index])
+            raise

+ 112 - 3
coriolis/db/sqlalchemy/models.py

@@ -7,7 +7,9 @@ from oslo_db.sqlalchemy import models
 import sqlalchemy
 import sqlalchemy
 from sqlalchemy.ext import declarative
 from sqlalchemy.ext import declarative
 from sqlalchemy import orm
 from sqlalchemy import orm
+from sqlalchemy import schema
 
 
+from coriolis import constants
 from coriolis.db.sqlalchemy import types
 from coriolis.db.sqlalchemy import types
 
 
 BASE = declarative.declarative_base()
 BASE = declarative.declarative_base()
@@ -277,6 +279,110 @@ class Migration(BaseTransferAction):
         })
         })
         return base
         return base
 
 
+class ServiceRegionMapping(
+        BASE, models.TimestampMixin, models.ModelBase, models.SoftDeleteMixin):
+    __tablename__ = "service_region_mapping"
+
+    id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        default=lambda: str(uuid.uuid4()),
+        nullable=False,
+        primary_key=True)
+
+    service_id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        sqlalchemy.ForeignKey('service.id'),
+        nullable=False)
+
+    region_id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        sqlalchemy.ForeignKey('region.id'),
+        nullable=False)
+
+
+class Service(BASE, models.TimestampMixin, models.ModelBase,
+              models.SoftDeleteMixin):
+    __tablename__ = "service"
+    __table_args__ = (
+        schema.UniqueConstraint("host", "topic", "deleted",
+                                name="uniq_services0host0topic0deleted"),
+        schema.UniqueConstraint("host", "binary", "deleted",
+                                name="uniq_services0host0binary0deleted"))
+
+    id = sqlalchemy.Column(
+        sqlalchemy.String(36), default=lambda: str(uuid.uuid4()),
+        primary_key=True)
+
+    host = sqlalchemy.Column(
+        sqlalchemy.String(255), nullable=False)
+    binary = sqlalchemy.Column(
+        sqlalchemy.String(255), nullable=False)
+    topic = sqlalchemy.Column(
+        sqlalchemy.String(255), nullable=True, default=None)
+    enabled = sqlalchemy.Column(
+        sqlalchemy.Boolean, nullable=False, default=lambda: False)
+    status = sqlalchemy.Column(
+        sqlalchemy.String(255), nullable=False,
+        default=lambda: constants.SERVICE_STATUS_UNKNOWN)
+    providers = sqlalchemy.Column(types.Json(), nullable=True)
+    specs = sqlalchemy.Column(types.Json(), nullable=True)
+    mapped_regions = orm.relationship(
+        'Region', back_populates='mapped_services',
+        secondary="service_region_mapping")
+
+
+class EndpointRegionMapping(
+        BASE, models.TimestampMixin, models.ModelBase, models.SoftDeleteMixin):
+    __tablename__ = "endpoint_region_mapping"
+
+    id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        default=lambda: str(uuid.uuid4()),
+        nullable=False,
+        primary_key=True)
+
+    endpoint_id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        sqlalchemy.ForeignKey('endpoint.id'),
+        nullable=False)
+
+    region_id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        sqlalchemy.ForeignKey('region.id'),
+        nullable=False)
+
+
+class Region(
+        BASE, models.TimestampMixin, models.ModelBase, models.SoftDeleteMixin):
+    __tablename__ = "region"
+
+    id = sqlalchemy.Column(
+        sqlalchemy.String(36),
+        default=lambda: str(uuid.uuid4()),
+        nullable=False,
+        primary_key=True)
+
+    name = sqlalchemy.Column(
+        sqlalchemy.String(255),
+        nullable=False)
+
+    description = sqlalchemy.Column(
+        sqlalchemy.String(1024),
+        nullable=True)
+
+    enabled = sqlalchemy.Column(
+        sqlalchemy.Boolean,
+        default=lambda: False,
+        nullable=False)
+
+    mapped_endpoints = orm.relationship(
+        'Endpoint', back_populates='mapped_regions',
+        secondary="endpoint_region_mapping")
+
+    mapped_services = orm.relationship(
+        'Service', back_populates='mapped_regions',
+        secondary="service_region_mapping")
+
 
 
 class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
 class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
                models.SoftDeleteMixin):
                models.SoftDeleteMixin):
@@ -294,11 +400,14 @@ class Endpoint(BASE, models.TimestampMixin, models.ModelBase,
     origin_actions = orm.relationship(
     origin_actions = orm.relationship(
         BaseTransferAction, backref=orm.backref('origin_endpoint'),
         BaseTransferAction, backref=orm.backref('origin_endpoint'),
         primaryjoin="and_(BaseTransferAction.origin_endpoint_id==Endpoint.id, "
         primaryjoin="and_(BaseTransferAction.origin_endpoint_id==Endpoint.id, "
-        "BaseTransferAction.deleted=='0')")
+                    "BaseTransferAction.deleted=='0')")
     destination_actions = orm.relationship(
     destination_actions = orm.relationship(
         BaseTransferAction, backref=orm.backref('destination_endpoint'),
         BaseTransferAction, backref=orm.backref('destination_endpoint'),
         primaryjoin="and_(BaseTransferAction.destination_endpoint_id=="
         primaryjoin="and_(BaseTransferAction.destination_endpoint_id=="
-        "Endpoint.id, BaseTransferAction.deleted=='0')")
+                    "Endpoint.id, BaseTransferAction.deleted=='0')")
+    mapped_regions = orm.relationship(
+        'Region', back_populates='mapped_endpoints',
+        secondary="endpoint_region_mapping")
 
 
 
 
 class ReplicaSchedule(BASE, models.TimestampMixin, models.ModelBase,
 class ReplicaSchedule(BASE, models.TimestampMixin, models.ModelBase,
@@ -317,7 +426,7 @@ class ReplicaSchedule(BASE, models.TimestampMixin, models.ModelBase,
     expiration_date = sqlalchemy.Column(
     expiration_date = sqlalchemy.Column(
         sqlalchemy.types.DateTime, nullable=True)
         sqlalchemy.types.DateTime, nullable=True)
     enabled = sqlalchemy.Column(
     enabled = sqlalchemy.Column(
-        sqlalchemy.Boolean, nullable=False, default=True)
+        sqlalchemy.Boolean, nullable=False, default=lambda: False)
     shutdown_instance = sqlalchemy.Column(
     shutdown_instance = sqlalchemy.Column(
         sqlalchemy.Boolean, nullable=False, default=False)
         sqlalchemy.Boolean, nullable=False, default=False)
     trust_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
     trust_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)

+ 3 - 2
coriolis/endpoints/api.py

@@ -10,9 +10,10 @@ class API(object):
         self._rpc_client = rpc_client.ConductorClient()
         self._rpc_client = rpc_client.ConductorClient()
 
 
     def create(self, ctxt, name, endpoint_type, description,
     def create(self, ctxt, name, endpoint_type, description,
-               connection_info):
+               connection_info, mapped_regions):
         return self._rpc_client.create_endpoint(
         return self._rpc_client.create_endpoint(
-            ctxt, name, endpoint_type, description, connection_info)
+            ctxt, name, endpoint_type, description, connection_info,
+            mapped_regions)
 
 
     def update(self, ctxt, endpoint_id, properties):
     def update(self, ctxt, endpoint_id, properties):
         return self._rpc_client.update_endpoint(
         return self._rpc_client.update_endpoint(

+ 39 - 0
coriolis/exception.py

@@ -148,6 +148,7 @@ class Invalid(CoriolisException):
     code = 400
     code = 400
     safe = True
     safe = True
 
 
+
 class InvalidCustomOSDetectTools(Invalid):
 class InvalidCustomOSDetectTools(Invalid):
     message = _("The provided custom OS detect tools are invalid.")
     message = _("The provided custom OS detect tools are invalid.")
 
 
@@ -245,6 +246,10 @@ class NotFound(CoriolisException):
     safe = True
     safe = True
 
 
 
 
+class RegionNotFound(NotFound):
+    message = _("The specified Coriolis region(s) could not be found.")
+
+
 class OSMorphingToolsNotFound(NotFound):
 class OSMorphingToolsNotFound(NotFound):
     message = _(
     message = _(
         'No OSMorphing tools were found for OS type "%(os_type)s" for this VM.'
         'No OSMorphing tools were found for OS type "%(os_type)s" for this VM.'
@@ -388,3 +393,37 @@ class UnrecognizedWorkerInitSystem(CoriolisException):
         "Could not determine init system for temporary worker VM. The image "
         "Could not determine init system for temporary worker VM. The image "
         "used for the worker VM must use systemd as an init system for "
         "used for the worker VM must use systemd as an init system for "
         "Coriolis to be able to use it for data Replication.")
         "Coriolis to be able to use it for data Replication.")
+
+
+class NoRegionError(CoriolisException):
+    safe = True
+    code = 503
+    message = _(
+        "No Coriolis region is avaialable to process this request at this "
+        "time.")
+
+
+class NoSuitableRegionError(NoRegionError):
+    message = _(
+        "No Coriolis Region(s) fitting the criteria of the required operation "
+        "could be found.")
+
+
+class NoServiceError(CoriolisException):
+    safe = True
+    code = 503
+    message = _(
+        "No service is avaialable to process this request at this time.")
+
+
+class NoWorkerServiceError(NoServiceError):
+    message = _(
+        "No Coriolis Worker Service(s) were found. Please ensure that "
+        "at least one or Coriolis Worker Service(s) are registered "
+        "within the Coriolis installation.")
+
+
+class NoSuitableWorkerServiceError(NoServiceError):
+    message = _(
+        "No suitable Coriolis Worker service was found which fits the "
+        "criteria for the required operation.")

+ 79 - 0
coriolis/policies/regions.py

@@ -0,0 +1,79 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+
+from oslo_policy import policy
+
+from coriolis.policies import base
+
+
+REGIONS_POLICY_PREFIX = "%s:regions" % base.CORIOLIS_POLICIES_PREFIX
+REGIONS_POLICY_DEFAULT_RULE = "rule:admin_or_owner"
+
+
+def get_regions_policy_label(rule_label):
+    return "%s:%s" % (
+        REGIONS_POLICY_PREFIX, rule_label)
+
+
+REGIONS_POLICY_DEFAULT_RULES = [
+    policy.DocumentedRuleDefault(
+        get_regions_policy_label('create'),
+        REGIONS_POLICY_DEFAULT_RULE,
+        "Create a region",
+        [
+            {
+                "path": "/regions",
+                "method": "POST"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_regions_policy_label('list'),
+        REGIONS_POLICY_DEFAULT_RULE,
+        "List regions",
+        [
+            {
+                "path": "/regions",
+                "method": "GET"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_regions_policy_label('show'),
+        REGIONS_POLICY_DEFAULT_RULE,
+        "Show details for region",
+        [
+            {
+                "path": "/regions/{region_id}",
+                "method": "GET"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_regions_policy_label('update'),
+        REGIONS_POLICY_DEFAULT_RULE,
+        "Update details for region",
+        [
+            {
+                "path": "/regions/{region_id}",
+                "method": "PUT"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_regions_policy_label('delete'),
+        REGIONS_POLICY_DEFAULT_RULE,
+        "Delete region",
+        [
+            {
+                "path": "/regions/{region_id}",
+                "method": "DELETE"
+            }
+        ]
+    )
+]
+
+
+def list_rules():
+    return REGIONS_POLICY_DEFAULT_RULES

+ 79 - 0
coriolis/policies/services.py

@@ -0,0 +1,79 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+
+from oslo_policy import policy
+
+from coriolis.policies import base
+
+
+SERVICES_POLICY_PREFIX = "%s:services" % base.CORIOLIS_POLICIES_PREFIX
+SERVICES_POLICY_DEFAULT_RULE = "rule:admin_or_owner"
+
+
+def get_services_policy_label(rule_label):
+    return "%s:%s" % (
+        SERVICES_POLICY_PREFIX, rule_label)
+
+
+SERVICES_POLICY_DEFAULT_RULES = [
+    policy.DocumentedRuleDefault(
+        get_services_policy_label('create'),
+        SERVICES_POLICY_DEFAULT_RULE,
+        "Create a service",
+        [
+            {
+                "path": "/services",
+                "method": "POST"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_services_policy_label('list'),
+        SERVICES_POLICY_DEFAULT_RULE,
+        "List services",
+        [
+            {
+                "path": "/services",
+                "method": "GET"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_services_policy_label('show'),
+        SERVICES_POLICY_DEFAULT_RULE,
+        "Show details for service",
+        [
+            {
+                "path": "/services/{service_id}",
+                "method": "GET"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_services_policy_label('update'),
+        SERVICES_POLICY_DEFAULT_RULE,
+        "Update details for service",
+        [
+            {
+                "path": "/services/{service_id}",
+                "method": "PUT"
+            }
+        ]
+    ),
+    policy.DocumentedRuleDefault(
+        get_services_policy_label('delete'),
+        SERVICES_POLICY_DEFAULT_RULE,
+        "Delete service",
+        [
+            {
+                "path": "/services/{service_id}",
+                "method": "DELETE"
+            }
+        ]
+    )
+]
+
+
+def list_rules():
+    return SERVICES_POLICY_DEFAULT_RULES

+ 4 - 2
coriolis/policy.py

@@ -14,9 +14,11 @@ from coriolis.policies import diagnostics
 from coriolis.policies import endpoints
 from coriolis.policies import endpoints
 from coriolis.policies import general
 from coriolis.policies import general
 from coriolis.policies import migrations
 from coriolis.policies import migrations
+from coriolis.policies import regions
 from coriolis.policies import replicas
 from coriolis.policies import replicas
 from coriolis.policies import replica_schedules
 from coriolis.policies import replica_schedules
 from coriolis.policies import replica_tasks_executions
 from coriolis.policies import replica_tasks_executions
+from coriolis.policies import services
 
 
 
 
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
@@ -26,7 +28,7 @@ _ENFORCER = None
 
 
 DEFAULT_POLICIES_MODULES = [
 DEFAULT_POLICIES_MODULES = [
     base, endpoints, general, migrations, replicas, replica_schedules,
     base, endpoints, general, migrations, replicas, replica_schedules,
-    replica_tasks_executions, diagnostics]
+    replica_tasks_executions, diagnostics, regions, services]
 
 
 
 
 def reset():
 def reset():
@@ -61,7 +63,7 @@ def check_policy_for_context(
     """ Checks the validity of the given action of the given target based on
     """ Checks the validity of the given action of the given target based on
     set policies.
     set policies.
     On success, returns a value where bool(val) == True.
     On success, returns a value where bool(val) == True.
-    On failure and if `do_raise` if False, returns False.
+    On failure and if `do_raise` is False, returns False.
     Raises `exception.PolicyNotAuthorized` or `exc` if the policy is
     Raises `exception.PolicyNotAuthorized` or `exc` if the policy is
     not authorized.
     not authorized.
     """
     """

+ 1 - 1
coriolis/providers/factory.py

@@ -61,7 +61,7 @@ def get_available_providers():
             provider_data = providers.get(cls.platform, {})
             provider_data = providers.get(cls.platform, {})
 
 
             provider_types = provider_data.get("types", [])
             provider_types = provider_data.get("types", [])
-            if (provider_class in cls.__bases__ and
+            if (provider_class in cls.__mro__ and
                     provider_type not in provider_types):
                     provider_type not in provider_types):
                 provider_types.append(provider_type)
                 provider_types.append(provider_type)
 
 

+ 0 - 0
coriolis/regions/__init__.py


+ 27 - 0
coriolis/regions/api.py

@@ -0,0 +1,27 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from coriolis import utils
+from coriolis.conductor.rpc import client as rpc_client
+
+
+class API(object):
+    def __init__(self):
+        self._rpc_client = rpc_client.ConductorClient()
+
+    def create(self, ctxt, region_name, description, enabled=True):
+        return self._rpc_client.create_region(
+            ctxt, region_name, description=description, enabled=enabled)
+
+    def update(self, ctxt, region_id, updated_values):
+        return self._rpc_client.update_region(
+            ctxt, region_id, updated_values=updated_values)
+
+    def delete(self, ctxt, region_id):
+        self._rpc_client.delete_region(ctxt, region_id)
+
+    def get_regions(self, ctxt):
+        return self._rpc_client.get_regions(ctxt)
+
+    def get_region(self, ctxt, region_id):
+        return self._rpc_client.get_region(ctxt, region_id)

+ 3 - 2
coriolis/replica_cron/rpc/client.py

@@ -3,15 +3,16 @@
 
 
 import oslo_messaging as messaging
 import oslo_messaging as messaging
 
 
+from coriolis import constants
 from coriolis import rpc
 from coriolis import rpc
 
 
 VERSION = "1.0"
 VERSION = "1.0"
 
 
 
 
 class ReplicaCronClient(object):
 class ReplicaCronClient(object):
-    def __init__(self):
+    def __init__(self, topic=constants.REPLICA_CRON_MAIN_MESSAGING_TOPIC):
         target = messaging.Target(
         target = messaging.Target(
-            topic='coriolis_replica_cron_worker', version=VERSION)
+            topic=topic, version=VERSION)
         self._client = rpc.get_client(target)
         self._client = rpc.get_client(target)
 
 
     def register(self, ctxt, schedule):
     def register(self, ctxt, schedule):

+ 0 - 0
coriolis/scheduler/__init__.py


+ 0 - 0
coriolis/scheduler/filters/__init__.py


+ 22 - 0
coriolis/scheduler/filters/base.py

@@ -0,0 +1,22 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import abc
+
+from six import with_metaclass
+
+
+class BaseServiceFilter(with_metaclass(abc.ABCMeta)):
+
+    def is_service_acceptable(self, service):
+        return self.rate_service(service) > 0
+
+    def filter_services(self, services):
+        return [
+            service for service in services
+            if self.is_service_acceptable(service)]
+
+    @abc.abstractmethod
+    def rate_service(self, service):
+        """ Returns a rating out of 100 for the service. """
+        pass

+ 117 - 0
coriolis/scheduler/filters/trivial_filters.py

@@ -0,0 +1,117 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from oslo_log import log as logging
+
+from coriolis import constants
+from coriolis.scheduler.filters import base
+
+
+LOG = logging.getLogger(__name__)
+
+
+class RegionsFilter(base.BaseServiceFilter):
+
+    def __init__(self, regions, any_region=False):
+        self._regions = regions
+        self._any_region = any_region
+
+    def __repr__(self):
+        return "<%s(regions=%s, any_region=%s)>" % (
+            self.__class__.__name__, self._regions, self._any_region)
+
+    def rate_service(self, service):
+        if not self._regions:
+            LOG.debug(
+                "No regions specified for this filter (%s). "
+                "Presuming service is valid.")
+            return 100
+
+        service_regions = [
+            region.id for region in service.mapped_regions]
+        found = []
+        missing = []
+        for region in self._regions:
+            if region in service_regions:
+                found.append(region)
+            else:
+                missing.append(region)
+
+        if not found:
+            LOG.debug(
+                "None of the requested regions are available on service (%s): "
+                "%s", service.id, self._regions)
+            return 0
+        if not self._any_region and missing:
+            LOG.debug(
+                "The following required regions are missing from service "
+                "with ID '%s': %s", service.id, missing)
+            return 0
+
+        return 100
+
+
+class TopicFilter(base.BaseServiceFilter):
+
+    def __init__(self, topic):
+        self._topic = topic
+
+    def __repr__(self):
+        return "<%s(topic=%s)>" % (
+            self.__class__.__name__, self._topic)
+
+    def rate_service(self, service):
+        if service.topic == self._topic:
+            return 100
+        return 0
+
+
+class EnabledFilter(base.BaseServiceFilter):
+
+    def __init__(self, enabled=True):
+        self._enabled = enabled
+
+    def __repr__(self):
+        return "<%s(enabled=%s)>" % (
+            self.__class__.__name__, self._enabled)
+
+    def rate_service(self, service):
+        if service.enabled == self._enabled:
+            return 100
+        return 0
+
+
+class ProviderTypesFilter(base.BaseServiceFilter):
+
+    def __init__(self, provider_requirements):
+        """ Filters based on requested provider capabilities.
+        :param provider_requirements: dict of the form {
+            "<platform_type>": [constants.PROVIDER_TYPE_*, ...]}
+        """
+        self._provider_requirements = provider_requirements
+
+    def __repr__(self):
+        return "<%s(provider_requirements=%s)>" % (
+            self.__class__.__name__, self._provider_requirements)
+
+    def rate_service(self, service):
+        for platform_type in self._provider_requirements:
+            if platform_type not in service.providers:
+                LOG.debug(
+                    "Service with ID '%s' does not have a provider for platform "
+                    "type '%s'", service.id, platform_type)
+                return 0
+
+            available_types = service.providers[
+                platform_type].get('types', [])
+            missing_types = [
+                typ for typ in self._provider_requirements[platform_type]
+                if typ not in available_types]
+            if missing_types:
+                LOG.debug(
+                    "Service with ID '%s' is missing the following required "
+                    "provider types for platform '%s': %s",
+                    service.id, platform_type, missing_types)
+                return 0
+
+        return 100

+ 0 - 0
coriolis/scheduler/rpc/__init__.py


+ 36 - 0
coriolis/scheduler/rpc/client.py

@@ -0,0 +1,36 @@
+# Copyright 2016 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from oslo_config import cfg
+import oslo_messaging as messaging
+
+from coriolis import rpc
+
+VERSION = "1.0"
+
+scheduler_opts = [
+    cfg.IntOpt("scheduler_rpc_timeout",
+               help="Number of seconds until RPC calls to the "
+                    "scheduler timeout.")
+]
+
+CONF = cfg.CONF
+CONF.register_opts(scheduler_opts, 'scheduler')
+
+
+class SchedulerClient(object):
+    def __init__(self, timeout=None):
+        target = messaging.Target(topic='coriolis_scheduler', version=VERSION)
+        if timeout is None:
+            timeout = CONF.scheduler.scheduler_rpc_timeout
+        self._client = rpc.get_client(target, timeout=timeout)
+
+    def get_diagnostics(self, ctxt):
+        return self._client.call(ctxt, 'get_diagnostics')
+
+    def get_workers_for_specs(
+            self, ctxt, provider_requirements=None,
+            region_sets=None, enabled=None):
+        return self._client.call(
+            ctxt, 'get_workers_for_specs', region_sets=region_sets,
+            enabled=enabled, provider_requirements=provider_requirements)

+ 170 - 0
coriolis/scheduler/rpc/server.py

@@ -0,0 +1,170 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+import copy
+import functools
+import random
+import uuid
+
+from oslo_config import cfg
+from oslo_log import log as logging
+
+from coriolis import constants
+from coriolis import exception
+from coriolis import utils
+from coriolis.conductor.rpc import client as rpc_conductor_client
+from coriolis.scheduler.filters import trivial_filters
+from coriolis.db import api as db_api
+
+
+VERSION = "1.0"
+
+LOG = logging.getLogger(__name__)
+
+
+SCHEDULER_OPTS = []
+
+CONF = cfg.CONF
+CONF.register_opts(SCHEDULER_OPTS, 'scheduler')
+
+
+class SchedulerServerEndpoint(object):
+    def __init__(self):
+        self._rpc_conductor_client = rpc_conductor_client.ConductorClient()
+
+    def get_diagnostics(self, ctxt):
+        return utils.get_diagnostics_info()
+
+    def _get_all_worker_services(self, ctxt):
+        services = db_api.get_services(ctxt)
+        services = trivial_filters.TopicFilter(
+            constants.WORKER_MAIN_MESSAGING_TOPIC).filter_services(
+                services)
+        if not services:
+            raise exception.NoWorkerServiceError()
+
+        return services
+
+    def _get_weighted_filtered_services(
+            self, services, filters, minimum_per_filter_rating=1):
+        """ Returns list of services and their scores for the given filters.
+        Services which are rejected by any filter will be excluded.
+        """
+        if not filters:
+            LOG.warn(
+                "No filters provided. Presuming all services acceptable.")
+            return [(service, 100) for service in services]
+
+        scores = []
+
+
+        service_ids = [service.id for service in services]
+        LOG.debug(
+            "Running following filters on worker services '%s': %s",
+            service_ids, filters)
+        for service in services:
+            total_score = 0
+
+            acceptable = True
+            flt = None
+            for flt in filters:
+                rating = flt.rate_service(service)
+                if rating < minimum_per_filter_rating:
+                    acceptable = False
+                    break
+                total_score = total_score + rating
+            if not acceptable:
+                LOG.debug(
+                    "Service with ID '%s' was rejected by filter %r",
+                    service.id, flt)
+                continue
+
+            scores.append((service, total_score))
+
+        if not scores:
+            message = (
+                "None of the inspected Coriolis Worker services (IDs %s) "
+                "matched the requested filtering criteria (minimum score %d) "
+                "for the following required filters: %s" % (
+                    [s.id for s in services],
+                    minimum_per_filter_rating, filters))
+            raise exception.NoSuitableWorkerServiceError(message)
+
+        LOG.debug(
+            "Determined following scores for services based on filters '%s': "
+            "%s", filters, scores)
+
+        return sorted(
+            scores, key=lambda s: s[1], reverse=True)
+
+    def _filter_regions(
+            self, ctxt, region_ids, enabled=True, check_all_exist=True,
+            regions_cache=None):
+        found_regions = []
+        filtered_regions = []
+        regions = regions_cache
+        if not regions:
+            regions = db_api.get_regions(ctxt)
+        for region in regions:
+            if region.id in region_ids:
+                found_regions.append(region.id)
+                if region.enabled != enabled:
+                    continue
+                filtered_regions.append(region)
+
+        if check_all_exist:
+            missing_regions = set(region_ids).difference(
+                set(found_regions))
+            if missing_regions:
+                raise exception.RegionNotFound(
+                    "Failed to schedule job on regions %s as one or more "
+                    "of the proposed regions (%s) do not exist." % (
+                        region_ids, missing_regions))
+
+        return filtered_regions
+
+    def get_workers_for_specs(
+            self, ctxt, provider_requirements=None,
+            region_sets=None, enabled=None, filter_disabled_regions=True):
+        """ Returns a list of enabled Worker Services with the specified
+        parameters.
+        :param provider_requirements: dict of the form {
+            "<platform_type>": [constants.PROVIDER_TYPE_*, ...]}
+        param region_sets: list of lists of region IDs to filter for.
+        Services will be filtered unless they are associated with
+        at least one region in each region set.
+        """
+        filters = []
+        worker_services = self._get_all_worker_services(ctxt)
+
+        LOG.debug(
+            "Searching for Worker Services with specs: %s" % {
+                "provider_requirements": provider_requirements,
+                "region_sets": region_sets, "enabled": enabled})
+
+        if enabled is not None:
+            filters.append(trivial_filters.EnabledFilter(enabled=enabled))
+        if region_sets:
+            for region_set in region_sets:
+                filtered_regions = self._filter_regions(
+                    ctxt, region_set, enabled=filter_disabled_regions,
+                    check_all_exist=True)
+                if not filtered_regions:
+                    raise exception.NoSuitableRegionError(
+                        "None of the selected Regions (%s) are enabled or "
+                        "otherwise usable." % region_set)
+                filters.append(
+                    trivial_filters.RegionsFilter(region_set, any_region=True))
+        if provider_requirements:
+            filters.append(
+                trivial_filters.ProviderTypesFilter(provider_requirements))
+
+        filtered_services = self._get_weighted_filtered_services(
+            worker_services, filters)
+        LOG.info(
+            "Found Worker Services %s for specs: %s" % (
+                filtered_services, {
+                    "provider_requirements": provider_requirements,
+                    "region_sets": region_sets, "enabled": enabled}))
+
+        return [s[0] for s in filtered_services]

+ 0 - 0
coriolis/services/__init__.py


+ 29 - 0
coriolis/services/api.py

@@ -0,0 +1,29 @@
+# Copyright 2020 Cloudbase Solutions Srl
+# All Rights Reserved.
+
+from coriolis import utils
+from coriolis.conductor.rpc import client as rpc_client
+
+
+class API(object):
+    def __init__(self):
+        self._rpc_client = rpc_client.ConductorClient()
+
+    def create(
+            self, ctxt, host, binary, topic, mapped_regions,
+            enabled):
+        return self._rpc_client.register_service(
+            ctxt, host, binary, topic, enabled, mapped_regions)
+
+    def update(self, ctxt, service_id, updated_values):
+        return self._rpc_client.update_service(
+            ctxt, service_id, updated_values)
+
+    def delete(self, ctxt, region_id):
+        self._rpc_client.delete_service(ctxt, region_id)
+
+    def get_services(self, ctxt):
+        return self._rpc_client.get_services(ctxt)
+
+    def get_service(self, ctxt, service_id):
+        return self._rpc_client.get_service(ctxt, service_id)

+ 23 - 10
coriolis/tasks/base.py

@@ -52,27 +52,40 @@ class TaskRunner(with_metaclass(abc.ABCMeta)):
 
 
         return required_libs
         return required_libs
 
 
-    @property
-    @abc.abstractmethod
-    def required_task_info_properties(self):
+    @abc.abstractclassmethod
+    def get_required_task_info_properties(cls):
         """ Returns a list of the string fields which are required
         """ Returns a list of the string fields which are required
         to be present during the tasks' run method. """
         to be present during the tasks' run method. """
         pass
         pass
 
 
-    @property
-    @abc.abstractmethod
-    def returned_task_info_properties(self):
+    @abc.abstractclassmethod
+    def get_returned_task_info_properties(cls):
         """ Returns a list of the string fields which are returned by the
         """ Returns a list of the string fields which are returned by the
         tasks' run method to be added to the task info.
         tasks' run method to be added to the task info.
         """
         """
         pass
         pass
 
 
+    @abc.abstractclassmethod
+    def get_required_provider_types(cls):
+        """ Returns a dict with 'source/destination' as keys containing a list
+        of all the provider types (constants.PROVIDER_TYPE_*) required for the
+        task.
+        """
+        pass
+
+    @abc.abstractclassmethod
+    def get_required_platform(cls):
+        """ Returns whether the task operates on the source platform, the
+        destination, or both. (constants.TASK_PLATFORM_*)
+        """
+        pass
+
     @abc.abstractmethod
     @abc.abstractmethod
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         """ The actual logic run by the task.
         """ The actual logic run by the task.
         Should return a dict with all the fields declared by
         Should return a dict with all the fields declared by
-        'self.returned_task_info_properties'.
+        'self.get_returned_task_info_properties'.
         Must be implemented in all child classes.
         Must be implemented in all child classes.
         """
         """
         pass
         pass
@@ -84,7 +97,7 @@ class TaskRunner(with_metaclass(abc.ABCMeta)):
         NOTE: This should NOT modify the existing task_info in any way.
         NOTE: This should NOT modify the existing task_info in any way.
         """
         """
         missing_info_props = [
         missing_info_props = [
-            prop for prop in self.required_task_info_properties
+            prop for prop in self.get_required_task_info_properties()
             if prop not in task_info]
             if prop not in task_info]
         if missing_info_props:
         if missing_info_props:
             raise exception.CoriolisException(
             raise exception.CoriolisException(
@@ -102,7 +115,7 @@ class TaskRunner(with_metaclass(abc.ABCMeta)):
                     self.__class__, type(result), result))
                     self.__class__, type(result), result))
 
 
         missing_returns = [
         missing_returns = [
-            prop for prop in self.returned_task_info_properties
+            prop for prop in self.get_returned_task_info_properties()
             if prop not in result.keys()]
             if prop not in result.keys()]
         if missing_returns:
         if missing_returns:
             raise exception.CoriolisException(
             raise exception.CoriolisException(
@@ -114,7 +127,7 @@ class TaskRunner(with_metaclass(abc.ABCMeta)):
 
 
         undeclared_returns = [
         undeclared_returns = [
             prop for prop in result.keys()
             prop for prop in result.keys()
-            if prop not in self.returned_task_info_properties]
+            if prop not in self.get_returned_task_info_properties()]
         if undeclared_returns:
         if undeclared_returns:
             raise exception.CoriolisException(
             raise exception.CoriolisException(
                 "Task type '%s' returned the following undeclared "
                 "Task type '%s' returned the following undeclared "

+ 15 - 4
coriolis/tasks/migration_tasks.py

@@ -14,14 +14,25 @@ LOG = logging.getLogger(__name__)
 
 
 class GetOptimalFlavorTask(base.TaskRunner):
 class GetOptimalFlavorTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "target_environment"]
         return ["export_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["instance_deployment_info"]
         return ["instance_deployment_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_INSTANCE_FLAVOR]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(

+ 47 - 12
coriolis/tasks/osmorphing_tasks.py

@@ -16,16 +16,29 @@ LOG = logging.getLogger(__name__)
 
 
 class OSMorphingTask(base.TaskRunner):
 class OSMorphingTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return [
         return [
             "osmorphing_info", "osmorphing_connection_info",
             "osmorphing_info", "osmorphing_connection_info",
             "user_scripts"]
             "user_scripts"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return []
         return []
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT],
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT],
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
 
 
@@ -64,16 +77,27 @@ class OSMorphingTask(base.TaskRunner):
 
 
 class DeployOSMorphingResourcesTask(base.TaskRunner):
 class DeployOSMorphingResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "instance_deployment_info"]
         return ["target_environment", "instance_deployment_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return [
         return [
             "os_morphing_resources", "osmorphing_info",
             "os_morphing_resources", "osmorphing_info",
             "osmorphing_connection_info"]
             "osmorphing_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_OS_MORPHING]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -124,14 +148,25 @@ class DeployOSMorphingResourcesTask(base.TaskRunner):
 
 
 class DeleteOSMorphingResourcesTask(base.TaskRunner):
 class DeleteOSMorphingResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "os_morphing_resources"]
         return ["target_environment", "os_morphing_resources"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["os_morphing_resources", "osmorphing_connection_info"]
         return ["os_morphing_resources", "osmorphing_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_OS_MORPHING]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(

+ 317 - 84
coriolis/tasks/replica_tasks.py

@@ -66,14 +66,25 @@ def _check_ensure_volumes_info_ordering(export_info, volumes_info):
 class GetInstanceInfoTask(base.TaskRunner):
 class GetInstanceInfoTask(base.TaskRunner):
     """ Task which gathers the export info for a VM.  """
     """ Task which gathers the export info for a VM.  """
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["source_environment"]
         return ["source_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["export_info"]
         return ["export_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -96,14 +107,25 @@ class GetInstanceInfoTask(base.TaskRunner):
 class ShutdownInstanceTask(base.TaskRunner):
 class ShutdownInstanceTask(base.TaskRunner):
     """ Task which shuts down a VM. """
     """ Task which shuts down a VM. """
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["source_environment"]
         return ["source_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return []
         return []
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -119,18 +141,32 @@ class ShutdownInstanceTask(base.TaskRunner):
 
 
 class ReplicateDisksTask(base.TaskRunner):
 class ReplicateDisksTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        # NOTE: considering Replication reads from one end (be it PMR minion
+        # or otherwise) to the disk writer minion on the destination,
+        # replicate_disks would need access to both:
+        return constants.TASK_PLATFORM_BILATERAL
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return [
         return [
             "export_info", "volumes_info", "source_environment",
             "export_info", "volumes_info", "source_environment",
             "source_resources",
             "source_resources",
             "source_resources_connection_info",
             "source_resources_connection_info",
             "target_resources_connection_info"]
             "target_resources_connection_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -181,15 +217,26 @@ class ReplicateDisksTask(base.TaskRunner):
 
 
 class DeployReplicaDisksTask(base.TaskRunner):
 class DeployReplicaDisksTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return [
         return [
             "export_info", "volumes_info", "target_environment"]
             "export_info", "volumes_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         target_environment = task_info['target_environment']
         target_environment = task_info['target_environment']
@@ -216,15 +263,25 @@ class DeployReplicaDisksTask(base.TaskRunner):
 
 
 class DeleteReplicaSourceDiskSnapshotsTask(base.TaskRunner):
 class DeleteReplicaSourceDiskSnapshotsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return [
         return [
             "volumes_info", "source_environment"]
             "volumes_info", "source_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
 
 
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
@@ -253,15 +310,26 @@ class DeleteReplicaSourceDiskSnapshotsTask(base.TaskRunner):
 
 
 class DeleteReplicaDisksTask(base.TaskRunner):
 class DeleteReplicaDisksTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return [
         return [
             "volumes_info", "target_environment"]
             "volumes_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         event_manager = events.EventManager(event_handler)
         event_manager = events.EventManager(event_handler)
@@ -295,14 +363,25 @@ class DeleteReplicaDisksTask(base.TaskRunner):
 
 
 class DeployReplicaSourceResourcesTask(base.TaskRunner):
 class DeployReplicaSourceResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["source_environment", "export_info"]
         return ["source_environment", "export_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["source_resources", "source_resources_connection_info"]
         return ["source_resources", "source_resources_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -355,14 +434,25 @@ class DeployReplicaSourceResourcesTask(base.TaskRunner):
 
 
 class DeleteReplicaSourceResourcesTask(base.TaskRunner):
 class DeleteReplicaSourceResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["source_environment", "source_resources"]
         return ["source_environment", "source_resources"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["source_resources", "source_resources_connection_info"]
         return ["source_resources", "source_resources_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -384,16 +474,27 @@ class DeleteReplicaSourceResourcesTask(base.TaskRunner):
 
 
 class DeployReplicaTargetResourcesTask(base.TaskRunner):
 class DeployReplicaTargetResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "volumes_info", "target_environment"]
         return ["export_info", "volumes_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return [
         return [
             "volumes_info", "target_resources",
             "volumes_info", "target_resources",
             "target_resources_connection_info"]
             "target_resources_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         target_environment = task_info["target_environment"]
         target_environment = task_info["target_environment"]
@@ -470,15 +571,26 @@ class DeployReplicaTargetResourcesTask(base.TaskRunner):
 
 
 class DeleteReplicaTargetResourcesTask(base.TaskRunner):
 class DeleteReplicaTargetResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_resources", "target_environment"]
         return ["target_resources", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return [
         return [
             "target_resources", "target_resources_connection_info"]
             "target_resources", "target_resources_connection_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -500,14 +612,25 @@ class DeleteReplicaTargetResourcesTask(base.TaskRunner):
 
 
 class DeployReplicaInstanceResourcesTask(base.TaskRunner):
 class DeployReplicaInstanceResourcesTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "target_environment", "clone_disks"]
         return ["export_info", "target_environment", "clone_disks"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["instance_deployment_info"]
         return ["instance_deployment_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         target_environment = task_info["target_environment"]
         target_environment = task_info["target_environment"]
@@ -533,14 +656,25 @@ class DeployReplicaInstanceResourcesTask(base.TaskRunner):
 
 
 class FinalizeReplicaInstanceDeploymentTask(base.TaskRunner):
 class FinalizeReplicaInstanceDeploymentTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "instance_deployment_info"]
         return ["target_environment", "instance_deployment_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["transfer_result"]
         return ["transfer_result"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -564,14 +698,25 @@ class FinalizeReplicaInstanceDeploymentTask(base.TaskRunner):
 
 
 class CleanupFailedReplicaInstanceDeploymentTask(base.TaskRunner):
 class CleanupFailedReplicaInstanceDeploymentTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "instance_deployment_info"]
         return ["target_environment", "instance_deployment_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["instance_deployment_info"]
         return ["instance_deployment_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -591,14 +736,25 @@ class CleanupFailedReplicaInstanceDeploymentTask(base.TaskRunner):
 
 
 class CreateReplicaDiskSnapshotsTask(base.TaskRunner):
 class CreateReplicaDiskSnapshotsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "export_info", "volumes_info"]
         return ["target_environment", "export_info", "volumes_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -624,14 +780,25 @@ class CreateReplicaDiskSnapshotsTask(base.TaskRunner):
 
 
 class DeleteReplicaTargetDiskSnapshotsTask(base.TaskRunner):
 class DeleteReplicaTargetDiskSnapshotsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "export_info", "volumes_info"]
         return ["target_environment", "export_info", "volumes_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         export_info = task_info['export_info']
         export_info = task_info['export_info']
@@ -657,14 +824,25 @@ class DeleteReplicaTargetDiskSnapshotsTask(base.TaskRunner):
 
 
 class RestoreReplicaDiskSnapshotsTask(base.TaskRunner):
 class RestoreReplicaDiskSnapshotsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["target_environment", "export_info", "volumes_info"]
         return ["target_environment", "export_info", "volumes_info"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info"]
         return ["volumes_info"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         provider = providers_factory.get_provider(
         provider = providers_factory.get_provider(
@@ -690,14 +868,25 @@ class RestoreReplicaDiskSnapshotsTask(base.TaskRunner):
 
 
 class ValidateReplicaExecutionSourceInputsTask(base.TaskRunner):
 class ValidateReplicaExecutionSourceInputsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["source_environment"]
         return ["source_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return []
         return []
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_VALIDATE_REPLICA_EXPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         event_manager = events.EventManager(event_handler)
         event_manager = events.EventManager(event_handler)
@@ -720,14 +909,25 @@ class ValidateReplicaExecutionSourceInputsTask(base.TaskRunner):
 
 
 class ValidateReplicaExecutionDestinationInputsTask(base.TaskRunner):
 class ValidateReplicaExecutionDestinationInputsTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "target_environment"]
         return ["export_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return []
         return []
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_VALIDATE_REPLICA_IMPORT]
+        }
+
     def _validate_provider_replica_import_input(
     def _validate_provider_replica_import_input(
             self, provider, ctxt, conn_info, target_environment, export_info):
             self, provider, ctxt, conn_info, target_environment, export_info):
         provider.validate_replica_import_input(
         provider.validate_replica_import_input(
@@ -769,14 +969,25 @@ class ValidateReplicaExecutionDestinationInputsTask(base.TaskRunner):
 
 
 class ValidateReplicaDeploymentParametersTask(base.TaskRunner):
 class ValidateReplicaDeploymentParametersTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "target_environment"]
         return ["export_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return []
         return []
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_VALIDATE_REPLICA_IMPORT]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         event_manager = events.EventManager(event_handler)
         event_manager = events.EventManager(event_handler)
@@ -811,14 +1022,25 @@ class ValidateReplicaDeploymentParametersTask(base.TaskRunner):
 
 
 class UpdateSourceReplicaTask(base.TaskRunner):
 class UpdateSourceReplicaTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_SOURCE
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["volumes_info", "source_environment"]
         return ["volumes_info", "source_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info", "source_environment"]
         return ["volumes_info", "source_environment"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_SOURCE: [
+                constants.PROVIDER_TYPE_SOURCE_REPLICA_UPDATE]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         event_manager = events.EventManager(event_handler)
         event_manager = events.EventManager(event_handler)
@@ -868,14 +1090,25 @@ class UpdateSourceReplicaTask(base.TaskRunner):
 
 
 class UpdateDestinationReplicaTask(base.TaskRunner):
 class UpdateDestinationReplicaTask(base.TaskRunner):
 
 
-    @property
-    def required_task_info_properties(self):
+    @classmethod
+    def get_required_platform(cls):
+        return constants.TASK_PLATFORM_DESTINATION
+
+    @classmethod
+    def get_required_task_info_properties(cls):
         return ["export_info", "volumes_info", "target_environment"]
         return ["export_info", "volumes_info", "target_environment"]
 
 
-    @property
-    def returned_task_info_properties(self):
+    @classmethod
+    def get_returned_task_info_properties(cls):
         return ["volumes_info", "target_environment"]
         return ["volumes_info", "target_environment"]
 
 
+    @classmethod
+    def get_required_provider_types(cls):
+        return {
+            constants.PROVIDER_PLATFORM_DESTINATION: [
+                constants.PROVIDER_TYPE_DESTINATION_REPLICA_UPDATE]
+        }
+
     def _run(self, ctxt, instance, origin, destination, task_info,
     def _run(self, ctxt, instance, origin, destination, task_info,
              event_handler):
              event_handler):
         event_manager = events.EventManager(event_handler)
         event_manager = events.EventManager(event_handler)

+ 7 - 2
coriolis/utils.py

@@ -16,6 +16,7 @@ import re
 import socket
 import socket
 import string
 import string
 import subprocess
 import subprocess
+import sys
 import time
 import time
 import traceback
 import traceback
 import uuid
 import uuid
@@ -122,10 +123,10 @@ def get_diagnostics_info():
     # diagnostics.
     # diagnostics.
     packages = list(freeze.freeze())
     packages = list(freeze.freeze())
     return {
     return {
-        "application": os.path.basename(main.__file__),
+        "application": get_binary_name(),
         "packages": packages,
         "packages": packages,
         "os_info": _get_host_os_info(),
         "os_info": _get_host_os_info(),
-        "hostname": platform.node(),
+        "hostname": get_hostname(),
         "ip_addresses": _get_local_ips(),
         "ip_addresses": _get_local_ips(),
     }
     }
 
 
@@ -428,6 +429,10 @@ def get_hostname():
     return socket.gethostname()
     return socket.gethostname()
 
 
 
 
+def get_binary_name():
+    return os.path.basename(sys.argv[0])
+
+
 def get_exception_details():
 def get_exception_details():
     return traceback.format_exc()
     return traceback.format_exc()
 
 

+ 7 - 5
coriolis/worker/rpc/client.py

@@ -4,6 +4,7 @@
 from oslo_config import cfg
 from oslo_config import cfg
 import oslo_messaging as messaging
 import oslo_messaging as messaging
 
 
+from coriolis import constants
 from coriolis import rpc
 from coriolis import rpc
 
 
 VERSION = "1.0"
 VERSION = "1.0"
@@ -19,8 +20,9 @@ CONF.register_opts(worker_opts, 'worker')
 
 
 
 
 class WorkerClient(object):
 class WorkerClient(object):
-    def __init__(self, timeout=None):
-        target = messaging.Target(topic='coriolis_worker', version=VERSION)
+    def __init__(
+            self, timeout=None, topic=constants.WORKER_MAIN_MESSAGING_TOPIC):
+        target = messaging.Target(topic=topic, version=VERSION)
         if timeout is None:
         if timeout is None:
             timeout = CONF.worker.worker_rpc_timeout
             timeout = CONF.worker.worker_rpc_timeout
         self._client = rpc.get_client(target, timeout=timeout)
         self._client = rpc.get_client(target, timeout=timeout)
@@ -39,9 +41,6 @@ class WorkerClient(object):
         cctxt.call(ctxt, 'cancel_task', task_id=task_id, process_id=process_id,
         cctxt.call(ctxt, 'cancel_task', task_id=task_id, process_id=process_id,
                    force=force)
                    force=force)
 
 
-    def update_migration_status(self, ctxt, task_id, status):
-        self._client.call(ctxt, "update_migration_status", status=status)
-
     def get_endpoint_instances(self, ctxt, platform_name, connection_info,
     def get_endpoint_instances(self, ctxt, platform_name, connection_info,
                                source_environment, marker=None, limit=None,
                                source_environment, marker=None, limit=None,
                                instance_name_pattern=None):
                                instance_name_pattern=None):
@@ -128,3 +127,6 @@ class WorkerClient(object):
 
 
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
         return self._client.call(ctxt, 'get_diagnostics')
         return self._client.call(ctxt, 'get_diagnostics')
+
+    def get_service_status(self, ctxt):
+        return self._client.call(ctxt, 'get_service_status')

+ 87 - 6
coriolis/worker/rpc/server.py

@@ -16,7 +16,9 @@ import psutil
 from six.moves import queue
 from six.moves import queue
 
 
 from coriolis.conductor.rpc import client as rpc_conductor_client
 from coriolis.conductor.rpc import client as rpc_conductor_client
+from coriolis.conductor.rpc import utils as conductor_rpc_utils
 from coriolis import constants
 from coriolis import constants
+from coriolis import context
 from coriolis import events
 from coriolis import events
 from coriolis import exception
 from coriolis import exception
 from coriolis.providers import factory as providers_factory
 from coriolis.providers import factory as providers_factory
@@ -63,7 +65,36 @@ class _ConductorProviderEventHandler(events.BaseEventHandler):
 class WorkerServerEndpoint(object):
 class WorkerServerEndpoint(object):
     def __init__(self):
     def __init__(self):
         self._server = utils.get_hostname()
         self._server = utils.get_hostname()
-        self._rpc_conductor_client = rpc_conductor_client.ConductorClient()
+        self._service_registration = self._register_worker_service()
+
+    @property
+    def _rpc_conductor_client(self):
+        # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
+        # oslo_messaging clients as the underlying eventlet thread queues will
+        # be invalidated. Considering this class both serves from a "main
+        # process" as well as forking child processes, it is safest to
+        # re-instantiate the client every time:
+        return rpc_conductor_client.ConductorClient()
+
+    def _register_worker_service(self):
+        host = utils.get_hostname()
+        binary = utils.get_binary_name()
+        dummy_context = context.RequestContext(
+            # TODO(aznashwan): we should ideally have a dedicated
+            # user/pass/tenant just for service registration.
+            # Either way, these values are not used and thus redundant.
+            "coriolis", "admin")
+        status = self.get_service_status(dummy_context)
+        service_registration = (
+            conductor_rpc_utils.check_create_registration_for_service(
+                self._rpc_conductor_client, dummy_context, host, binary,
+                constants.WORKER_MAIN_MESSAGING_TOPIC, enabled=True,
+                providers=status['providers'], specs=status['specs']))
+        LOG.info(
+            "Worker service is successfully registered with the following "
+            "parameters: %s", service_registration)
+        self._service_registration = service_registration
+        return service_registration
 
 
     def _check_remove_dir(self, path):
     def _check_remove_dir(self, path):
         try:
         try:
@@ -73,7 +104,22 @@ class WorkerServerEndpoint(object):
             # Ignore the exception
             # Ignore the exception
             LOG.exception(ex)
             LOG.exception(ex)
 
 
+    def get_service_status(self, ctxt):
+        diagnostics = self.get_diagnostics(ctxt)
+        status = {
+            "host": diagnostics["hostname"],
+            "binary": diagnostics["application"],
+            "topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
+            "providers": self.get_available_providers(ctxt),
+            "specs": diagnostics
+        }
+
+        return status
+
     def cancel_task(self, ctxt, task_id, process_id, force):
     def cancel_task(self, ctxt, task_id, process_id, force):
+        LOG.debug(
+            "Received request to cancel task '%s' (process %s)",
+            task_id, process_id)
         if not force and os.name == "nt":
         if not force and os.name == "nt":
             LOG.warn("Windows does not support SIGINT, performing a "
             LOG.warn("Windows does not support SIGINT, performing a "
                      "forced task termination")
                      "forced task termination")
@@ -95,8 +141,6 @@ class WorkerServerEndpoint(object):
                 "completed/error'd." % (
                 "completed/error'd." % (
                     process_id, task_id))
                     process_id, task_id))
             LOG.error(msg)
             LOG.error(msg)
-            self._rpc_conductor_client.confirm_task_cancellation(
-                ctxt, task_id, msg)
 
 
     def _handle_mp_log_events(self, p, mp_log_q):
     def _handle_mp_log_events(self, p, mp_log_q):
         while True:
         while True:
@@ -182,12 +226,29 @@ class WorkerServerEndpoint(object):
         extra_library_paths = self._get_extra_library_paths_for_providers(
         extra_library_paths = self._get_extra_library_paths_for_providers(
             ctxt, task_id, task_type, origin, destination)
             ctxt, task_id, task_type, origin, destination)
 
 
-        self._start_process_with_custom_library_paths(p, extra_library_paths)
-        LOG.info("Task process started: %s", task_id)
         try:
         try:
+            LOG.debug(
+                "Attempting to set task host on Conductor for task '%s'.",
+                task_id)
             self._rpc_conductor_client.set_task_host(
             self._rpc_conductor_client.set_task_host(
-                ctxt, task_id, self._server, p.pid)
+                ctxt, task_id, self._server)
+            LOG.debug(
+                "Attempting to start process for task with ID '%s'", task_id)
+            self._start_process_with_custom_library_paths(
+                p, extra_library_paths)
+            LOG.info("Task process started: %s", task_id)
+            LOG.debug(
+                "Attempting to set task process on Conductor for task '%s'.",
+                task_id)
+            self._rpc_conductor_client.set_task_process(
+                ctxt, task_id, p.pid)
+            LOG.debug(
+                "Successfully started and retported task process for task "
+                "with ID '%s'.", task_id)
         except (Exception, KeyboardInterrupt) as ex:
         except (Exception, KeyboardInterrupt) as ex:
+            LOG.debug(
+                "Exception occurred whilst setting host for task '%s'. Error "
+                "was: %s", task_id, utils.get_exception_details())
             # NOTE: because the task error classes are wrapped,
             # NOTE: because the task error classes are wrapped,
             # it's easiest to just check that the messages align:
             # it's easiest to just check that the messages align:
             cancelling_msg = exception.TASK_ALREADY_CANCELLING_EXCEPTION_FMT % {
             cancelling_msg = exception.TASK_ALREADY_CANCELLING_EXCEPTION_FMT % {
@@ -204,10 +265,17 @@ class WorkerServerEndpoint(object):
         p.join()
         p.join()
 
 
         if result is None:
         if result is None:
+            LOG.debug(
+                "No result from process (%s) running task '%s'. "
+                "Presuming task was cancelled.",
+                p.pid, task_id)
             raise exception.TaskProcessCanceledException(
             raise exception.TaskProcessCanceledException(
                 "Task was canceled.")
                 "Task was canceled.")
 
 
         if isinstance(result, str):
         if isinstance(result, str):
+            LOG.debug(
+                "Error message while running task '%s' on process "
+                "with PID '%s': %s", task_id, p.pid, result)
             raise exception.TaskProcessException(result)
             raise exception.TaskProcessException(result)
         return result
         return result
 
 
@@ -226,10 +294,23 @@ class WorkerServerEndpoint(object):
             self._rpc_conductor_client.task_completed(
             self._rpc_conductor_client.task_completed(
                 ctxt, task_id, task_result)
                 ctxt, task_id, task_result)
         except exception.TaskProcessCanceledException as ex:
         except exception.TaskProcessCanceledException as ex:
+            LOG.debug(
+                "Task with ID '%s' appears to have been cancelled. "
+                "Confirming cancellation to Conductor now. Error was: %s",
+                task_id, utils.get_exception_details())
             LOG.exception(ex)
             LOG.exception(ex)
             self._rpc_conductor_client.confirm_task_cancellation(
             self._rpc_conductor_client.confirm_task_cancellation(
                 ctxt, task_id, str(ex))
                 ctxt, task_id, str(ex))
+        except exception.NoSuitableWorkerServiceError as ex:
+            LOG.warn(
+                "A conductor-side scheduling error has occurred following the "
+                "completion of task '%s'. Ignoring. Error was: %s",
+                task_id, utils.get_exception_details())
         except Exception as ex:
         except Exception as ex:
+            LOG.debug(
+                "Task with ID '%s' has error'd out. Reporting error to "
+                "Conductor now. Error was: %s",
+                task_id, utils.get_exception_details())
             LOG.exception(ex)
             LOG.exception(ex)
             self._rpc_conductor_client.set_task_error(ctxt, task_id, str(ex))
             self._rpc_conductor_client.set_task_error(ctxt, task_id, str(ex))
 
 

+ 1 - 0
setup.cfg

@@ -29,6 +29,7 @@ console_scripts =
     coriolis-conductor = coriolis.cmd.conductor:main
     coriolis-conductor = coriolis.cmd.conductor:main
     coriolis-worker = coriolis.cmd.worker:main
     coriolis-worker = coriolis.cmd.worker:main
     coriolis-replica-cron = coriolis.cmd.replica_cron:main
     coriolis-replica-cron = coriolis.cmd.replica_cron:main
+    coriolis-scheduler= coriolis.cmd.scheduler:main
     coriolis-dbsync = coriolis.cmd.db_sync:main
     coriolis-dbsync = coriolis.cmd.db_sync:main
 
 
 [wheel]
 [wheel]