Просмотр исходного кода

Ensure RPC call connection cleanup across all components.

Due to some underlying issue related to the usage/implementation of the
oslo_messaging.RPCClient, cross-components connecions are not getting
properly closed, leading to numerous hanging connections and eventual
saturation of all ports available on the Rabbit host.

This PR adds a base class in the `coriolis.rpc` module which facilitates
automatic cleanup for all RPC calls.
Nashwan Azhari 5 лет назад
Родитель
Сommit
3cbf03a6f1

+ 91 - 87
coriolis/conductor/rpc/client.py

@@ -19,43 +19,44 @@ CONF = cfg.CONF
 CONF.register_opts(conductor_opts, 'conductor')
 CONF.register_opts(conductor_opts, 'conductor')
 
 
 
 
-class ConductorClient(object):
+class ConductorClient(rpc.BaseRPCClient):
     def __init__(self, timeout=None,
     def __init__(self, timeout=None,
                  topic=constants.CONDUCTOR_MAIN_MESSAGING_TOPIC):
                  topic=constants.CONDUCTOR_MAIN_MESSAGING_TOPIC):
         target = messaging.Target(topic=topic, version=VERSION)
         target = messaging.Target(topic=topic, version=VERSION)
         if timeout is None:
         if timeout is None:
             timeout = CONF.conductor.conductor_rpc_timeout
             timeout = CONF.conductor.conductor_rpc_timeout
-        self._client = rpc.get_client(target, timeout=timeout)
+        super(ConductorClient, self).__init__(
+            target, timeout=timeout)
 
 
     def create_endpoint(self, ctxt, name, endpoint_type, description,
     def create_endpoint(self, ctxt, name, endpoint_type, description,
                         connection_info, mapped_regions):
                         connection_info, mapped_regions):
-        return self._client.call(
+        return self._call(
             ctxt, 'create_endpoint', name=name, endpoint_type=endpoint_type,
             ctxt, 'create_endpoint', name=name, endpoint_type=endpoint_type,
             description=description, connection_info=connection_info,
             description=description, connection_info=connection_info,
             mapped_regions=mapped_regions)
             mapped_regions=mapped_regions)
 
 
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
     def update_endpoint(self, ctxt, endpoint_id, updated_values):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_endpoint',
             ctxt, 'update_endpoint',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             updated_values=updated_values)
             updated_values=updated_values)
 
 
     def get_endpoints(self, ctxt):
     def get_endpoints(self, ctxt):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoints')
             ctxt, 'get_endpoints')
 
 
     def get_endpoint(self, ctxt, endpoint_id):
     def get_endpoint(self, ctxt, endpoint_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint', endpoint_id=endpoint_id)
             ctxt, 'get_endpoint', endpoint_id=endpoint_id)
 
 
     def delete_endpoint(self, ctxt, endpoint_id):
     def delete_endpoint(self, ctxt, endpoint_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_endpoint', endpoint_id=endpoint_id)
             ctxt, 'delete_endpoint', endpoint_id=endpoint_id)
 
 
     def get_endpoint_instances(self, ctxt, endpoint_id, source_environment,
     def get_endpoint_instances(self, ctxt, endpoint_id, source_environment,
                                marker=None, limit=None,
                                marker=None, limit=None,
                                instance_name_pattern=None):
                                instance_name_pattern=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_instances',
             ctxt, 'get_endpoint_instances',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             source_environment=source_environment,
             source_environment=source_environment,
@@ -65,7 +66,7 @@ class ConductorClient(object):
 
 
     def get_endpoint_instance(
     def get_endpoint_instance(
             self, ctxt, endpoint_id, source_environment, instance_name):
             self, ctxt, endpoint_id, source_environment, instance_name):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_instance',
             ctxt, 'get_endpoint_instance',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             source_environment=source_environment,
             source_environment=source_environment,
@@ -73,83 +74,83 @@ class ConductorClient(object):
 
 
     def get_endpoint_source_options(
     def get_endpoint_source_options(
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_source_options',
             ctxt, 'get_endpoint_source_options',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             env=env, option_names=option_names)
             env=env, option_names=option_names)
 
 
     def get_endpoint_destination_options(
     def get_endpoint_destination_options(
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_destination_options',
             ctxt, 'get_endpoint_destination_options',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             env=env, option_names=option_names)
             env=env, option_names=option_names)
 
 
     def get_endpoint_networks(self, ctxt, endpoint_id, env):
     def get_endpoint_networks(self, ctxt, endpoint_id, env):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_networks',
             ctxt, 'get_endpoint_networks',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             env=env)
             env=env)
 
 
     def get_endpoint_storage(self, ctxt, endpoint_id, env):
     def get_endpoint_storage(self, ctxt, endpoint_id, env):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_storage',
             ctxt, 'get_endpoint_storage',
             endpoint_id=endpoint_id,
             endpoint_id=endpoint_id,
             env=env)
             env=env)
 
 
     def validate_endpoint_connection(self, ctxt, endpoint_id):
     def validate_endpoint_connection(self, ctxt, endpoint_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_connection',
             ctxt, 'validate_endpoint_connection',
             endpoint_id=endpoint_id)
             endpoint_id=endpoint_id)
 
 
     def validate_endpoint_target_environment(
     def validate_endpoint_target_environment(
             self, ctxt, endpoint_id, target_env):
             self, ctxt, endpoint_id, target_env):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_target_environment',
             ctxt, 'validate_endpoint_target_environment',
             endpoint_id=endpoint_id, target_env=target_env)
             endpoint_id=endpoint_id, target_env=target_env)
 
 
     def validate_endpoint_source_environment(
     def validate_endpoint_source_environment(
             self, ctxt, endpoint_id, source_env):
             self, ctxt, endpoint_id, source_env):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_source_environment',
             ctxt, 'validate_endpoint_source_environment',
             endpoint_id=endpoint_id, source_env=source_env)
             endpoint_id=endpoint_id, source_env=source_env)
 
 
     def get_available_providers(self, ctxt):
     def get_available_providers(self, ctxt):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_available_providers')
             ctxt, 'get_available_providers')
 
 
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_provider_schemas',
             ctxt, 'get_provider_schemas',
             platform_name=platform_name,
             platform_name=platform_name,
             provider_type=provider_type)
             provider_type=provider_type)
 
 
     def execute_replica_tasks(self, ctxt, replica_id,
     def execute_replica_tasks(self, ctxt, replica_id,
                               shutdown_instances=False):
                               shutdown_instances=False):
-        return self._client.call(
+        return self._call(
             ctxt, 'execute_replica_tasks', replica_id=replica_id,
             ctxt, 'execute_replica_tasks', replica_id=replica_id,
             shutdown_instances=shutdown_instances)
             shutdown_instances=shutdown_instances)
 
 
     def get_replica_tasks_executions(self, ctxt, replica_id,
     def get_replica_tasks_executions(self, ctxt, replica_id,
                                      include_tasks=False):
                                      include_tasks=False):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replica_tasks_executions',
             ctxt, 'get_replica_tasks_executions',
             replica_id=replica_id,
             replica_id=replica_id,
             include_tasks=include_tasks)
             include_tasks=include_tasks)
 
 
     def get_replica_tasks_execution(self, ctxt, replica_id, execution_id):
     def get_replica_tasks_execution(self, ctxt, replica_id, execution_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replica_tasks_execution', replica_id=replica_id,
             ctxt, 'get_replica_tasks_execution', replica_id=replica_id,
             execution_id=execution_id)
             execution_id=execution_id)
 
 
     def delete_replica_tasks_execution(self, ctxt, replica_id, execution_id):
     def delete_replica_tasks_execution(self, ctxt, replica_id, execution_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_replica_tasks_execution', replica_id=replica_id,
             ctxt, 'delete_replica_tasks_execution', replica_id=replica_id,
             execution_id=execution_id)
             execution_id=execution_id)
 
 
     def cancel_replica_tasks_execution(self, ctxt, replica_id, execution_id,
     def cancel_replica_tasks_execution(self, ctxt, replica_id, execution_id,
                                        force):
                                        force):
-        return self._client.call(
+        return self._call(
             ctxt, 'cancel_replica_tasks_execution', replica_id=replica_id,
             ctxt, 'cancel_replica_tasks_execution', replica_id=replica_id,
             execution_id=execution_id, force=force)
             execution_id=execution_id, force=force)
 
 
@@ -161,7 +162,7 @@ class ConductorClient(object):
                                  source_environment, destination_environment,
                                  source_environment, destination_environment,
                                  instances, network_map, storage_mappings,
                                  instances, network_map, storage_mappings,
                                  notes=None):
                                  notes=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'create_instances_replica',
             ctxt, 'create_instances_replica',
             origin_endpoint_id=origin_endpoint_id,
             origin_endpoint_id=origin_endpoint_id,
             destination_endpoint_id=destination_endpoint_id,
             destination_endpoint_id=destination_endpoint_id,
@@ -177,30 +178,30 @@ class ConductorClient(object):
             source_environment=source_environment)
             source_environment=source_environment)
 
 
     def get_replicas(self, ctxt, include_tasks_executions=False):
     def get_replicas(self, ctxt, include_tasks_executions=False):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replicas',
             ctxt, 'get_replicas',
             include_tasks_executions=include_tasks_executions)
             include_tasks_executions=include_tasks_executions)
 
 
     def get_replica(self, ctxt, replica_id):
     def get_replica(self, ctxt, replica_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replica', replica_id=replica_id)
             ctxt, 'get_replica', replica_id=replica_id)
 
 
     def delete_replica(self, ctxt, replica_id):
     def delete_replica(self, ctxt, replica_id):
-        self._client.call(
+        self._call(
             ctxt, 'delete_replica', replica_id=replica_id)
             ctxt, 'delete_replica', replica_id=replica_id)
 
 
     def delete_replica_disks(self, ctxt, replica_id):
     def delete_replica_disks(self, ctxt, replica_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_replica_disks', replica_id=replica_id)
             ctxt, 'delete_replica_disks', replica_id=replica_id)
 
 
     def get_migrations(self, ctxt, include_tasks=False,
     def get_migrations(self, ctxt, include_tasks=False,
                        include_info=False):
                        include_info=False):
-        return self._client.call(ctxt, 'get_migrations',
-                                 include_tasks=include_tasks,
-                                 include_info=include_info)
+        return self._call(
+            ctxt, 'get_migrations', include_tasks=include_tasks,
+            include_info=include_info)
 
 
     def get_migration(self, ctxt, migration_id):
     def get_migration(self, ctxt, migration_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_migration', migration_id=migration_id)
             ctxt, 'get_migration', migration_id=migration_id)
 
 
     def migrate_instances(self, ctxt, origin_endpoint_id,
     def migrate_instances(self, ctxt, origin_endpoint_id,
@@ -212,7 +213,7 @@ class ConductorClient(object):
                           replication_count, shutdown_instances=False,
                           replication_count, shutdown_instances=False,
                           notes=None, skip_os_morphing=False,
                           notes=None, skip_os_morphing=False,
                           user_scripts=None):
                           user_scripts=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'migrate_instances',
             ctxt, 'migrate_instances',
             origin_endpoint_id=origin_endpoint_id,
             origin_endpoint_id=origin_endpoint_id,
             destination_endpoint_id=destination_endpoint_id,
             destination_endpoint_id=destination_endpoint_id,
@@ -236,7 +237,7 @@ class ConductorClient(object):
                                  clone_disks=False,
                                  clone_disks=False,
                                  force=False, skip_os_morphing=False,
                                  force=False, skip_os_morphing=False,
                                  user_scripts=None):
                                  user_scripts=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'deploy_replica_instances', replica_id=replica_id,
             ctxt, 'deploy_replica_instances', replica_id=replica_id,
             instance_osmorphing_minion_pool_mappings=(
             instance_osmorphing_minion_pool_mappings=(
                 instance_osmorphing_minion_pool_mappings),
                 instance_osmorphing_minion_pool_mappings),
@@ -245,55 +246,58 @@ class ConductorClient(object):
             user_scripts=user_scripts)
             user_scripts=user_scripts)
 
 
     def delete_migration(self, ctxt, migration_id):
     def delete_migration(self, ctxt, migration_id):
-        self._client.call(
+        self._call(
             ctxt, 'delete_migration', migration_id=migration_id)
             ctxt, 'delete_migration', migration_id=migration_id)
 
 
     def cancel_migration(self, ctxt, migration_id, force):
     def cancel_migration(self, ctxt, migration_id, force):
-        self._client.call(
+        self._call(
             ctxt, 'cancel_migration', migration_id=migration_id, force=force)
             ctxt, 'cancel_migration', migration_id=migration_id, force=force)
 
 
     def set_task_host(self, ctxt, task_id, host):
     def set_task_host(self, ctxt, task_id, host):
-        self._client.call(
+        self._call(
             ctxt, 'set_task_host', task_id=task_id, host=host)
             ctxt, 'set_task_host', task_id=task_id, host=host)
 
 
     def set_task_process(self, ctxt, task_id, process_id):
     def set_task_process(self, ctxt, task_id, process_id):
-        self._client.call(
+        self._call(
             ctxt, 'set_task_process', task_id=task_id, process_id=process_id)
             ctxt, 'set_task_process', task_id=task_id, process_id=process_id)
 
 
     def task_completed(self, ctxt, task_id, task_result):
     def task_completed(self, ctxt, task_id, task_result):
-        self._client.call(
+        self._call(
             ctxt, 'task_completed', task_id=task_id, task_result=task_result)
             ctxt, 'task_completed', task_id=task_id, task_result=task_result)
 
 
     def confirm_task_cancellation(self, ctxt, task_id, cancellation_details):
     def confirm_task_cancellation(self, ctxt, task_id, cancellation_details):
-        self._client.call(
+        self._call(
             ctxt, 'confirm_task_cancellation', task_id=task_id,
             ctxt, 'confirm_task_cancellation', task_id=task_id,
             cancellation_details=cancellation_details)
             cancellation_details=cancellation_details)
 
 
     def set_task_error(self, ctxt, task_id, exception_details):
     def set_task_error(self, ctxt, task_id, exception_details):
-        self._client.call(ctxt, 'set_task_error', task_id=task_id,
-                          exception_details=exception_details)
+        self._call(
+            ctxt, 'set_task_error', task_id=task_id,
+            exception_details=exception_details)
 
 
     def task_event(self, ctxt, task_id, level, message):
     def task_event(self, ctxt, task_id, level, message):
-        self._client.cast(ctxt, 'task_event', task_id=task_id, level=level,
-                          message=message)
+        self._cast(
+            ctxt, 'task_event', task_id=task_id, level=level, message=message)
 
 
     def add_task_progress_update(self, ctxt, task_id, total_steps, message):
     def add_task_progress_update(self, ctxt, task_id, total_steps, message):
-        self._client.cast(ctxt, 'add_task_progress_update', task_id=task_id,
-                          total_steps=total_steps, message=message)
+        self._cast(
+            ctxt, 'add_task_progress_update', task_id=task_id,
+            total_steps=total_steps, message=message)
 
 
     def update_task_progress_update(self, ctxt, task_id, step,
     def update_task_progress_update(self, ctxt, task_id, step,
                                     total_steps, message):
                                     total_steps, message):
-        self._client.cast(ctxt, 'update_task_progress_update', task_id=task_id,
-                          step=step, total_steps=total_steps, message=message)
+        self._cast(
+            ctxt, 'update_task_progress_update', task_id=task_id,
+            step=step, total_steps=total_steps, message=message)
 
 
     def get_task_progress_step(self, ctxt, task_id):
     def get_task_progress_step(self, ctxt, task_id):
-        return self._client.call(ctxt, 'get_task_progress_step',
-                                 task_id=task_id)
+        return self._call(
+            ctxt, 'get_task_progress_step', task_id=task_id)
 
 
     def create_replica_schedule(self, ctxt, replica_id,
     def create_replica_schedule(self, ctxt, replica_id,
                                 schedule, enabled, exp_date,
                                 schedule, enabled, exp_date,
                                 shutdown_instance):
                                 shutdown_instance):
-        return self._client.call(
+        return self._call(
             ctxt, 'create_replica_schedule',
             ctxt, 'create_replica_schedule',
             replica_id=replica_id,
             replica_id=replica_id,
             schedule=schedule,
             schedule=schedule,
@@ -303,106 +307,106 @@ class ConductorClient(object):
 
 
     def update_replica_schedule(self, ctxt, replica_id, schedule_id,
     def update_replica_schedule(self, ctxt, replica_id, schedule_id,
                                 updated_values):
                                 updated_values):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_replica_schedule',
             ctxt, 'update_replica_schedule',
             replica_id=replica_id,
             replica_id=replica_id,
             schedule_id=schedule_id,
             schedule_id=schedule_id,
             updated_values=updated_values)
             updated_values=updated_values)
 
 
     def delete_replica_schedule(self, ctxt, replica_id, schedule_id):
     def delete_replica_schedule(self, ctxt, replica_id, schedule_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_replica_schedule',
             ctxt, 'delete_replica_schedule',
             replica_id=replica_id,
             replica_id=replica_id,
             schedule_id=schedule_id)
             schedule_id=schedule_id)
 
 
     def get_replica_schedules(self, ctxt, replica_id=None, expired=True):
     def get_replica_schedules(self, ctxt, replica_id=None, expired=True):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replica_schedules',
             ctxt, 'get_replica_schedules',
             replica_id=replica_id, expired=expired)
             replica_id=replica_id, expired=expired)
 
 
     def get_replica_schedule(self, ctxt, replica_id,
     def get_replica_schedule(self, ctxt, replica_id,
                              schedule_id, expired=True):
                              schedule_id, expired=True):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_replica_schedule',
             ctxt, 'get_replica_schedule',
             replica_id=replica_id,
             replica_id=replica_id,
             schedule_id=schedule_id,
             schedule_id=schedule_id,
             expired=expired)
             expired=expired)
 
 
     def update_replica(self, ctxt, replica_id, updated_properties):
     def update_replica(self, ctxt, replica_id, updated_properties):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_replica',
             ctxt, 'update_replica',
             replica_id=replica_id,
             replica_id=replica_id,
             updated_properties=updated_properties)
             updated_properties=updated_properties)
 
 
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
-        return self._client.call(ctxt, 'get_diagnostics')
+        return self._call(ctxt, 'get_diagnostics')
 
 
     def get_all_diagnostics(self, ctxt):
     def get_all_diagnostics(self, ctxt):
-        return self._client.call(ctxt, 'get_all_diagnostics')
+        return self._call(ctxt, 'get_all_diagnostics')
 
 
     def create_region(
     def create_region(
             self, ctxt, region_name, description="", enabled=True):
             self, ctxt, region_name, description="", enabled=True):
-        return self._client.call(
+        return self._call(
             ctxt, 'create_region',
             ctxt, 'create_region',
             region_name=region_name,
             region_name=region_name,
             description=description,
             description=description,
             enabled=enabled)
             enabled=enabled)
 
 
     def get_regions(self, ctxt):
     def get_regions(self, ctxt):
-        return self._client.call(ctxt, 'get_regions')
+        return self._call(ctxt, 'get_regions')
 
 
     def get_region(self, ctxt, region_id):
     def get_region(self, ctxt, region_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_region', region_id=region_id)
             ctxt, 'get_region', region_id=region_id)
 
 
     def update_region(self, ctxt, region_id, updated_values):
     def update_region(self, ctxt, region_id, updated_values):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_region',
             ctxt, 'update_region',
             region_id=region_id,
             region_id=region_id,
             updated_values=updated_values)
             updated_values=updated_values)
 
 
     def delete_region(self, ctxt, region_id):
     def delete_region(self, ctxt, region_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_region', region_id=region_id)
             ctxt, 'delete_region', region_id=region_id)
 
 
     def register_service(
     def register_service(
             self, ctxt, host, binary, topic, enabled, mapped_regions,
             self, ctxt, host, binary, topic, enabled, mapped_regions,
             providers=None, specs=None):
             providers=None, specs=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'register_service', host=host, binary=binary,
             ctxt, 'register_service', host=host, binary=binary,
             topic=topic, enabled=enabled, mapped_regions=mapped_regions,
             topic=topic, enabled=enabled, mapped_regions=mapped_regions,
             providers=providers, specs=specs)
             providers=providers, specs=specs)
 
 
     def check_service_registered(self, ctxt, host, binary, topic):
     def check_service_registered(self, ctxt, host, binary, topic):
-        return self._client.call(
+        return self._call(
             ctxt, 'check_service_registered', host=host, binary=binary,
             ctxt, 'check_service_registered', host=host, binary=binary,
             topic=topic)
             topic=topic)
 
 
     def refresh_service_status(self, ctxt, service_id):
     def refresh_service_status(self, ctxt, service_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'refresh_service_status', service_id=service_id)
             ctxt, 'refresh_service_status', service_id=service_id)
 
 
     def get_services(self, ctxt):
     def get_services(self, ctxt):
-        return self._client.call(ctxt, 'get_services')
+        return self._call(ctxt, 'get_services')
 
 
     def get_service(self, ctxt, service_id):
     def get_service(self, ctxt, service_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_service', service_id=service_id)
             ctxt, 'get_service', service_id=service_id)
 
 
     def update_service(self, ctxt, service_id, updated_values):
     def update_service(self, ctxt, service_id, updated_values):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_service', service_id=service_id,
             ctxt, 'update_service', service_id=service_id,
             updated_values=updated_values)
             updated_values=updated_values)
 
 
     def delete_service(self, ctxt, service_id):
     def delete_service(self, ctxt, service_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_service', service_id=service_id)
             ctxt, 'delete_service', service_id=service_id)
 
 
     def create_minion_pool(
     def create_minion_pool(
             self, ctxt, name, endpoint_id, pool_platform, pool_os_type,
             self, ctxt, name, endpoint_id, pool_platform, pool_os_type,
             environment_options, minimum_minions, maximum_minions,
             environment_options, minimum_minions, maximum_minions,
             minion_max_idle_time, minion_retention_strategy, notes=None):
             minion_max_idle_time, minion_retention_strategy, notes=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'create_minion_pool', name=name, endpoint_id=endpoint_id,
             ctxt, 'create_minion_pool', name=name, endpoint_id=endpoint_id,
             pool_platform=pool_platform, pool_os_type=pool_os_type,
             pool_platform=pool_platform, pool_os_type=pool_os_type,
             environment_options=environment_options,
             environment_options=environment_options,
@@ -413,89 +417,89 @@ class ConductorClient(object):
             notes=notes)
             notes=notes)
 
 
     def set_up_shared_minion_pool_resources(self, ctxt, minion_pool_id):
     def set_up_shared_minion_pool_resources(self, ctxt, minion_pool_id):
-        return self._client.call(
+        return self._call(
             ctxt, "set_up_shared_minion_pool_resources",
             ctxt, "set_up_shared_minion_pool_resources",
             minion_pool_id=minion_pool_id)
             minion_pool_id=minion_pool_id)
 
 
     def tear_down_shared_minion_pool_resources(
     def tear_down_shared_minion_pool_resources(
             self, ctxt, minion_pool_id, force=False):
             self, ctxt, minion_pool_id, force=False):
-        return self._client.call(
+        return self._call(
             ctxt, "tear_down_shared_minion_pool_resources",
             ctxt, "tear_down_shared_minion_pool_resources",
             minion_pool_id=minion_pool_id, force=force)
             minion_pool_id=minion_pool_id, force=force)
 
 
     def allocate_minion_pool_machines(self, ctxt, minion_pool_id):
     def allocate_minion_pool_machines(self, ctxt, minion_pool_id):
-        return self._client.call(
+        return self._call(
             ctxt, "allocate_minion_pool_machines",
             ctxt, "allocate_minion_pool_machines",
             minion_pool_id=minion_pool_id)
             minion_pool_id=minion_pool_id)
 
 
     def deallocate_minion_pool_machines(
     def deallocate_minion_pool_machines(
             self, ctxt, minion_pool_id, force=False):
             self, ctxt, minion_pool_id, force=False):
-        return self._client.call(
+        return self._call(
             ctxt, "deallocate_minion_pool_machines",
             ctxt, "deallocate_minion_pool_machines",
             minion_pool_id=minion_pool_id,
             minion_pool_id=minion_pool_id,
             force=force)
             force=force)
 
 
     def get_minion_pools(self, ctxt):
     def get_minion_pools(self, ctxt):
-        return self._client.call(ctxt, 'get_minion_pools')
+        return self._call(ctxt, 'get_minion_pools')
 
 
     def get_minion_pool(self, ctxt, minion_pool_id):
     def get_minion_pool(self, ctxt, minion_pool_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_minion_pool', minion_pool_id=minion_pool_id)
             ctxt, 'get_minion_pool', minion_pool_id=minion_pool_id)
 
 
     def update_minion_pool(self, ctxt, minion_pool_id, updated_values):
     def update_minion_pool(self, ctxt, minion_pool_id, updated_values):
-        return self._client.call(
+        return self._call(
             ctxt, 'update_minion_pool',
             ctxt, 'update_minion_pool',
             minion_pool_id=minion_pool_id, updated_values=updated_values)
             minion_pool_id=minion_pool_id, updated_values=updated_values)
 
 
     def delete_minion_pool(self, ctxt, minion_pool_id):
     def delete_minion_pool(self, ctxt, minion_pool_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_minion_pool', minion_pool_id=minion_pool_id)
             ctxt, 'delete_minion_pool', minion_pool_id=minion_pool_id)
 
 
     def get_minion_pool_lifecycle_executions(
     def get_minion_pool_lifecycle_executions(
             self, ctxt, minion_pool_id, include_tasks=False):
             self, ctxt, minion_pool_id, include_tasks=False):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_minion_pool_lifecycle_executions',
             ctxt, 'get_minion_pool_lifecycle_executions',
             minion_pool_id=minion_pool_id, include_tasks=include_tasks)
             minion_pool_id=minion_pool_id, include_tasks=include_tasks)
 
 
     def get_minion_pool_lifecycle_execution(
     def get_minion_pool_lifecycle_execution(
             self, ctxt, minion_pool_id, execution_id):
             self, ctxt, minion_pool_id, execution_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_minion_pool_lifecycle_execution',
             ctxt, 'get_minion_pool_lifecycle_execution',
             minion_pool_id=minion_pool_id, execution_id=execution_id)
             minion_pool_id=minion_pool_id, execution_id=execution_id)
 
 
     def delete_minion_pool_lifecycle_execution(
     def delete_minion_pool_lifecycle_execution(
             self, ctxt, minion_pool_id, execution_id):
             self, ctxt, minion_pool_id, execution_id):
-        return self._client.call(
+        return self._call(
             ctxt, 'delete_minion_pool_lifecycle_execution',
             ctxt, 'delete_minion_pool_lifecycle_execution',
             minion_pool_id=minion_pool_id, execution_id=execution_id)
             minion_pool_id=minion_pool_id, execution_id=execution_id)
 
 
     def cancel_minion_pool_lifecycle_execution(
     def cancel_minion_pool_lifecycle_execution(
             self, ctxt, minion_pool_id, execution_id, force):
             self, ctxt, minion_pool_id, execution_id, force):
-        return self._client.call(
+        return self._call(
             ctxt, 'cancel_minion_pool_lifecycle_execution',
             ctxt, 'cancel_minion_pool_lifecycle_execution',
             minion_pool_id=minion_pool_id, execution_id=execution_id,
             minion_pool_id=minion_pool_id, execution_id=execution_id,
             force=force)
             force=force)
 
 
     def get_endpoint_source_minion_pool_options(
     def get_endpoint_source_minion_pool_options(
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_source_minion_pool_options',
             ctxt, 'get_endpoint_source_minion_pool_options',
             endpoint_id=endpoint_id, env=env, option_names=option_names)
             endpoint_id=endpoint_id, env=env, option_names=option_names)
 
 
     def get_endpoint_destination_minion_pool_options(
     def get_endpoint_destination_minion_pool_options(
             self, ctxt, endpoint_id, env, option_names):
             self, ctxt, endpoint_id, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_destination_minion_pool_options',
             ctxt, 'get_endpoint_destination_minion_pool_options',
             endpoint_id=endpoint_id, env=env, option_names=option_names)
             endpoint_id=endpoint_id, env=env, option_names=option_names)
 
 
     def validate_endpoint_source_minion_pool_options(
     def validate_endpoint_source_minion_pool_options(
             self, ctxt, endpoint_id, pool_environment):
             self, ctxt, endpoint_id, pool_environment):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_source_minion_pool_options',
             ctxt, 'validate_endpoint_source_minion_pool_options',
             endpoint_id=endpoint_id, pool_environment=pool_environment)
             endpoint_id=endpoint_id, pool_environment=pool_environment)
 
 
     def validate_endpoint_destination_minion_pool_options(
     def validate_endpoint_destination_minion_pool_options(
             self, ctxt, endpoint_id, pool_environment):
             self, ctxt, endpoint_id, pool_environment):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_destination_minion_pool_options',
             ctxt, 'validate_endpoint_destination_minion_pool_options',
             endpoint_id=endpoint_id, pool_environment=pool_environment)
             endpoint_id=endpoint_id, pool_environment=pool_environment)

+ 18 - 44
coriolis/conductor/rpc/server.py

@@ -196,8 +196,8 @@ def minion_pool_synchronized(func):
 class ConductorServerEndpoint(object):
 class ConductorServerEndpoint(object):
     def __init__(self):
     def __init__(self):
         self._licensing_client = licensing_client.LicensingClient.from_env()
         self._licensing_client = licensing_client.LicensingClient.from_env()
-        self._scheduler_client_instance = None
         self._worker_client_instance = None
         self._worker_client_instance = None
+        self._scheduler_client_instance = None
         self._replica_cron_client_instance = None
         self._replica_cron_client_instance = None
 
 
     # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
     # NOTE(aznashwan): it is unsafe to fork processes with pre-instantiated
@@ -207,21 +207,21 @@ class ConductorServerEndpoint(object):
     # instantiate the clients only when needed:
     # instantiate the clients only when needed:
     @property
     @property
     def _worker_client(self):
     def _worker_client(self):
-        if not getattr(self, '_worker_client_instance'):
+        if not self._worker_client_instance:
             self._worker_client_instance = (
             self._worker_client_instance = (
                 rpc_worker_client.WorkerClient())
                 rpc_worker_client.WorkerClient())
         return self._worker_client_instance
         return self._worker_client_instance
 
 
     @property
     @property
     def _scheduler_client(self):
     def _scheduler_client(self):
-        if not getattr(self, '_scheduler_client_instance'):
+        if not self._scheduler_client_instance:
             self._scheduler_client_instance = (
             self._scheduler_client_instance = (
                 rpc_scheduler_client.SchedulerClient())
                 rpc_scheduler_client.SchedulerClient())
         return self._scheduler_client_instance
         return self._scheduler_client_instance
 
 
     @property
     @property
     def _replica_cron_client(self):
     def _replica_cron_client(self):
-        if not getattr(self, '_replica_cron_client_instance'):
+        if not self._replica_cron_client_instance:
             self._replica_cron_client_instance = (
             self._replica_cron_client_instance = (
                 rpc_cron_client.ReplicaCronClient())
                 rpc_cron_client.ReplicaCronClient())
         return self._replica_cron_client_instance
         return self._replica_cron_client_instance
@@ -234,29 +234,11 @@ class ConductorServerEndpoint(object):
         worker_diagnostics = []
         worker_diagnostics = []
         for worker_service in self._scheduler_client.get_workers_for_specs(
         for worker_service in self._scheduler_client.get_workers_for_specs(
                 ctxt):
                 ctxt):
-            worker_rpc = self._get_rpc_client_for_service(worker_service)
+            worker_rpc = self._get_worker_rpc_for_host(worker_service['host'])
             diagnostics.append(worker_rpc.get_diagnostics(ctxt))
             diagnostics.append(worker_rpc.get_diagnostics(ctxt))
 
 
         return diagnostics
         return diagnostics
 
 
-    def _get_rpc_client_for_service(self, service, *client_args, **client_kwargs):
-        rpc_client_class = RPC_TOPIC_TO_CLIENT_CLASS_MAP.get(service.topic)
-        if not rpc_client_class:
-            raise exception.NotFound(
-                "No RPC client class for service with topic '%s'." % (
-                    service.topic))
-
-        topic = service.topic
-        if service.topic == constants.WORKER_MAIN_MESSAGING_TOPIC:
-            # NOTE: coriolis.service.MessagingService-type services (such
-            # as the worker), always have a dedicated per-host queue
-            # which can be used to target the service:
-            topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
-                "main_topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
-                "host": service.host})
-
-        return rpc_client_class(*client_args, topic=topic, **client_kwargs)
-
     def _get_any_worker_service(self, ctxt, random_choice=False, raw_dict=False):
     def _get_any_worker_service(self, ctxt, random_choice=False, raw_dict=False):
         services = self._scheduler_client.get_workers_for_specs(ctxt)
         services = self._scheduler_client.get_workers_for_specs(ctxt)
         if not services:
         if not services:
@@ -268,13 +250,8 @@ class ConductorServerEndpoint(object):
             return service
             return service
         return db_api.get_service(ctxt, service['id'])
         return db_api.get_service(ctxt, service['id'])
 
 
-    def _get_worker_rpc_for_host(self, host, *client_args, **client_kwargs):
-        rpc_client_class = RPC_TOPIC_TO_CLIENT_CLASS_MAP[
-            constants.WORKER_MAIN_MESSAGING_TOPIC]
-        topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
-            "main_topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
-            "host": host})
-        return rpc_client_class(*client_args, topic=topic, **client_kwargs)
+    def _get_worker_rpc_for_host(self, worker_host, **client_kwargs):
+        return rpc_worker_client.WorkerClient(host=worker_host, **client_kwargs)
 
 
     def _get_worker_service_rpc_for_specs(
     def _get_worker_service_rpc_for_specs(
             self, ctxt, provider_requirements=None, region_sets=None,
             self, ctxt, provider_requirements=None, region_sets=None,
@@ -300,12 +277,11 @@ class ConductorServerEndpoint(object):
         selected_service = services[0]
         selected_service = services[0]
         if random_choice:
         if random_choice:
             selected_service = random.choice(services)
             selected_service = random.choice(services)
-        service = db_api.get_service(ctxt, selected_service["id"])
 
 
         LOG.info(
         LOG.info(
             "Was offered Worker Service with ID '%s' for requirements: %s",
             "Was offered Worker Service with ID '%s' for requirements: %s",
-            service.id, requirements_str)
-        return self._get_rpc_client_for_service(service)
+            selected_service['id'], requirements_str)
+        return self._get_worker_rpc_for_host(selected_service['host'])
 
 
     def _check_delete_reservation_for_transfer(self, transfer_action):
     def _check_delete_reservation_for_transfer(self, transfer_action):
         action_id = transfer_action.base_id
         action_id = transfer_action.base_id
@@ -618,14 +594,14 @@ class ConductorServerEndpoint(object):
     def get_available_providers(self, ctxt):
     def get_available_providers(self, ctxt):
         # TODO(aznashwan): merge list of all providers from all
         # TODO(aznashwan): merge list of all providers from all
         # worker services:
         # worker services:
-        worker_rpc = self._get_rpc_client_for_service(
-            self._get_any_worker_service(ctxt))
+        worker_rpc = self._get_worker_rpc_for_host(
+            self._get_any_worker_service(ctxt)['host'])
         return worker_rpc.get_available_providers(ctxt)
         return worker_rpc.get_available_providers(ctxt)
 
 
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
         # TODO(aznashwan): merge or version/namespace schemas for each worker?
         # TODO(aznashwan): merge or version/namespace schemas for each worker?
-        worker_rpc = self._get_rpc_client_for_service(
-            self._get_any_worker_service(ctxt))
+        worker_rpc = self._get_worker_rpc_for_host(
+            self._get_any_worker_service(ctxt)['host'])
         return worker_rpc.get_provider_schemas(
         return worker_rpc.get_provider_schemas(
             ctxt, platform_name, provider_type)
             ctxt, platform_name, provider_type)
 
 
@@ -793,7 +769,7 @@ class ConductorServerEndpoint(object):
                         retry_count=scheduling_retry_count,
                         retry_count=scheduling_retry_count,
                         retry_period=scheduling_retry_period)
                         retry_period=scheduling_retry_period)
                     worker_rpc.begin_task(
                     worker_rpc.begin_task(
-                        ctxt, server=None,
+                        ctxt,
                         task_id=task.id,
                         task_id=task.id,
                         task_type=task.task_type,
                         task_type=task.task_type,
                         origin=origin,
                         origin=origin,
@@ -2548,10 +2524,8 @@ class ConductorServerEndpoint(object):
                             # cancellation call to prevent the conductor from
                             # cancellation call to prevent the conductor from
                             # hanging an excessive amount of time:
                             # hanging an excessive amount of time:
                             task.host, timeout=10)
                             task.host, timeout=10)
-                        # NOTE: the RPC client is already prepped with the
-                        # right topic so we pass 'None' as the server host:
                         worker_rpc.cancel_task(
                         worker_rpc.cancel_task(
-                            ctxt, None, task.id, task.process_id, force)
+                            ctxt, task.id, task.process_id, force)
                     except (Exception, KeyboardInterrupt):
                     except (Exception, KeyboardInterrupt):
                         msg = (
                         msg = (
                             "Failed to send cancellation request for task '%s'"
                             "Failed to send cancellation request for task '%s'"
@@ -2957,7 +2931,7 @@ class ConductorServerEndpoint(object):
                 worker_rpc = self._get_worker_service_rpc_for_task(
                 worker_rpc = self._get_worker_service_rpc_for_task(
                     ctxt, task, origin_endpoint, destination_endpoint)
                     ctxt, task, origin_endpoint, destination_endpoint)
                 worker_rpc.begin_task(
                 worker_rpc.begin_task(
-                    ctxt, server=None,
+                    ctxt,
                     task_id=task.id,
                     task_id=task.id,
                     task_type=task.task_type,
                     task_type=task.task_type,
                     origin=origin,
                     origin=origin,
@@ -4001,7 +3975,7 @@ class ConductorServerEndpoint(object):
         service.status = constants.SERVICE_STATUS_UP
         service.status = constants.SERVICE_STATUS_UP
 
 
         if None in (providers, specs):
         if None in (providers, specs):
-            worker_rpc = self._get_rpc_client_for_service(service)
+            worker_rpc = self._get_worker_rpc_for_host(service['host'])
             status = worker_rpc.get_service_status(ctxt)
             status = worker_rpc.get_service_status(ctxt)
 
 
             service.providers = status["providers"]
             service.providers = status["providers"]
@@ -4050,7 +4024,7 @@ class ConductorServerEndpoint(object):
     def refresh_service_status(self, ctxt, service_id):
     def refresh_service_status(self, ctxt, service_id):
         LOG.debug("Updating registration for worker service '%s'", service_id)
         LOG.debug("Updating registration for worker service '%s'", service_id)
         service = db_api.get_service(ctxt, service_id)
         service = db_api.get_service(ctxt, service_id)
-        worker_rpc = self._get_rpc_client_for_service(service)
+        worker_rpc = self._get_worker_rpc_for_host(service['host'])
         status = worker_rpc.get_service_status(ctxt)
         status = worker_rpc.get_service_status(ctxt)
         updated_values = {
         updated_values = {
             "providers": status["providers"],
             "providers": status["providers"],

+ 6 - 6
coriolis/replica_cron/rpc/client.py

@@ -9,17 +9,17 @@ from coriolis import rpc
 VERSION = "1.0"
 VERSION = "1.0"
 
 
 
 
-class ReplicaCronClient(object):
+class ReplicaCronClient(rpc.BaseRPCClient):
     def __init__(self, topic=constants.REPLICA_CRON_MAIN_MESSAGING_TOPIC):
     def __init__(self, topic=constants.REPLICA_CRON_MAIN_MESSAGING_TOPIC):
         target = messaging.Target(
         target = messaging.Target(
             topic=topic, version=VERSION)
             topic=topic, version=VERSION)
-        self._client = rpc.get_client(target)
+        super(ReplicaCronClient, self).__init__(target)
 
 
     def register(self, ctxt, schedule):
     def register(self, ctxt, schedule):
-        self._client.call(ctxt, 'register', schedule=schedule)
+        self._call(ctxt, 'register', schedule=schedule)
 
 
     def unregister(self, ctxt, schedule):
     def unregister(self, ctxt, schedule):
-        self._client.call(ctxt, 'unregister', schedule=schedule)
-    
+        self._call(ctxt, 'unregister', schedule=schedule)
+
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
-        return self._client.call(ctxt, 'get_diagnostics')
+        return self._call(ctxt, 'get_diagnostics')

+ 66 - 14
coriolis/rpc.py

@@ -1,11 +1,16 @@
 # Copyright 2016 Cloudbase Solutions Srl
 # Copyright 2016 Cloudbase Solutions Srl
 # All Rights Reserved.
 # All Rights Reserved.
 
 
-from oslo_config import cfg
+import contextlib
+
 import oslo_messaging as messaging
 import oslo_messaging as messaging
+from oslo_config import cfg
+from oslo_log import log as logging
 
 
-from coriolis import context
 import coriolis.exception
 import coriolis.exception
+from coriolis import context
+from coriolis import utils
+
 
 
 rpc_opts = [
 rpc_opts = [
     cfg.StrOpt('messaging_transport_url',
     cfg.StrOpt('messaging_transport_url',
@@ -19,9 +24,10 @@ rpc_opts = [
 CONF = cfg.CONF
 CONF = cfg.CONF
 CONF.register_opts(rpc_opts)
 CONF.register_opts(rpc_opts)
 
 
+LOG = logging.getLogger(__name__)
+
 ALLOWED_EXMODS = [
 ALLOWED_EXMODS = [
-    coriolis.exception.__name__,
-]
+    coriolis.exception.__name__]
 
 
 
 
 class RequestContextSerializer(messaging.Serializer):
 class RequestContextSerializer(messaging.Serializer):
@@ -47,16 +53,9 @@ class RequestContextSerializer(messaging.Serializer):
 
 
 
 
 def _get_transport():
 def _get_transport():
-    return messaging.get_transport(cfg.CONF, CONF.messaging_transport_url,
-                                   allowed_remote_exmods=ALLOWED_EXMODS)
-
-
-def get_client(target, serializer=None, timeout=None):
-    serializer = RequestContextSerializer(serializer)
-    if timeout is None:
-        timeout = CONF.default_messaging_timeout
-    return messaging.RPCClient(
-        _get_transport(), target, serializer=serializer, timeout=timeout)
+    return messaging.get_transport(
+        cfg.CONF, CONF.messaging_transport_url,
+        allowed_remote_exmods=ALLOWED_EXMODS)
 
 
 
 
 def get_server(target, endpoints, serializer=None):
 def get_server(target, endpoints, serializer=None):
@@ -64,3 +63,56 @@ def get_server(target, endpoints, serializer=None):
     return messaging.get_rpc_server(_get_transport(), target, endpoints,
     return messaging.get_rpc_server(_get_transport(), target, endpoints,
                                     executor='eventlet',
                                     executor='eventlet',
                                     serializer=serializer)
                                     serializer=serializer)
+
+
+class BaseRPCClient(object):
+    """ Wrapper for 'oslo_messaging.RPCClient' which automatically
+    instantiates and cleans up transports for each call.
+    """
+
+    def __init__(self, target, timeout=None, serializer=None):
+        self._target = target
+        self._timeout = timeout
+        if self._timeout is None:
+            self._timeout = CONF.default_messaging_timeout
+        self._serializer = RequestContextSerializer(serializer)
+
+    def __repr__(self):
+        return "<RPCClient(target=%s, timeout=%s)>" % (
+            self._target, self._timeout)
+
+    @contextlib.contextmanager
+    def _rpc_messaging_client(self):
+        transport = None
+        try:
+            transport = _get_transport()
+            yield messaging.RPCClient(
+                transport, self._target, serializer=self._serializer,
+                timeout=self._timeout)
+        finally:
+            if transport:
+                try:
+                    transport.cleanup()
+                except (Exception, KeyboardInterrupt):
+                    LOG.warn(
+                        "Exception occurred while cleaning up transport for "
+                        "RPC client instance '%s'. Error was: %s",
+                        repr(self), utils.get_exception_details())
+
+    def _call(self, ctxt, method, **kwargs):
+        with self._rpc_messaging_client() as client:
+            return client.call(ctxt, method, **kwargs)
+
+    def _call_on_host(self, host, ctxt, method, **kwargs):
+        with self._rpc_messaging_client() as client:
+            cctxt = client.prepare(server=host)
+            return cctxt.call(ctxt, method, **kwargs)
+
+    def _cast(self, ctxt, method, **kwargs):
+        with self._rpc_messaging_client() as client:
+            client.cast(ctxt, method, **kwargs)
+
+    def _cast_for_host(self, host, ctxt, method, **kwargs):
+        with self._rpc_messaging_client() as client:
+            cctxt = client.prepare(server=host)
+            cctxt.cast(ctxt, method, **kwargs)

+ 5 - 4
coriolis/scheduler/rpc/client.py

@@ -18,19 +18,20 @@ CONF = cfg.CONF
 CONF.register_opts(scheduler_opts, 'scheduler')
 CONF.register_opts(scheduler_opts, 'scheduler')
 
 
 
 
-class SchedulerClient(object):
+class SchedulerClient(rpc.BaseRPCClient):
     def __init__(self, timeout=None):
     def __init__(self, timeout=None):
         target = messaging.Target(topic='coriolis_scheduler', version=VERSION)
         target = messaging.Target(topic='coriolis_scheduler', version=VERSION)
         if timeout is None:
         if timeout is None:
             timeout = CONF.scheduler.scheduler_rpc_timeout
             timeout = CONF.scheduler.scheduler_rpc_timeout
-        self._client = rpc.get_client(target, timeout=timeout)
+        super(SchedulerClient, self).__init__(
+            target, timeout=timeout)
 
 
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
-        return self._client.call(ctxt, 'get_diagnostics')
+        return self._call(ctxt, 'get_diagnostics')
 
 
     def get_workers_for_specs(
     def get_workers_for_specs(
             self, ctxt, provider_requirements=None,
             self, ctxt, provider_requirements=None,
             region_sets=None, enabled=None):
             region_sets=None, enabled=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_workers_for_specs', region_sets=region_sets,
             ctxt, 'get_workers_for_specs', region_sets=region_sets,
             enabled=enabled, provider_requirements=provider_requirements)
             enabled=enabled, provider_requirements=provider_requirements)

+ 33 - 28
coriolis/worker/rpc/client.py

@@ -19,32 +19,37 @@ CONF = cfg.CONF
 CONF.register_opts(worker_opts, 'worker')
 CONF.register_opts(worker_opts, 'worker')
 
 
 
 
-class WorkerClient(object):
+class WorkerClient(rpc.BaseRPCClient):
     def __init__(
     def __init__(
-            self, timeout=None, topic=constants.WORKER_MAIN_MESSAGING_TOPIC):
+            self, timeout=None, host=None,
+            base_worker_topic=constants.WORKER_MAIN_MESSAGING_TOPIC):
+        topic = base_worker_topic
+        if host is not None:
+            topic = constants.SERVICE_MESSAGING_TOPIC_FORMAT % ({
+                "main_topic": base_worker_topic,
+                "host": host})
         target = messaging.Target(topic=topic, version=VERSION)
         target = messaging.Target(topic=topic, version=VERSION)
         if timeout is None:
         if timeout is None:
             timeout = CONF.worker.worker_rpc_timeout
             timeout = CONF.worker.worker_rpc_timeout
-        self._client = rpc.get_client(target, timeout=timeout)
+        super(WorkerClient, self).__init__(
+            target, timeout=timeout)
 
 
-    def begin_task(self, ctxt, server, task_id, task_type, origin, destination,
+    def begin_task(self, ctxt, task_id, task_type, origin, destination,
                    instance, task_info):
                    instance, task_info):
-        cctxt = self._client.prepare(server=server)
-        cctxt.cast(
+        self._cast(
             ctxt, 'exec_task', task_id=task_id, task_type=task_type,
             ctxt, 'exec_task', task_id=task_id, task_type=task_type,
             origin=origin, destination=destination, instance=instance,
             origin=origin, destination=destination, instance=instance,
             task_info=task_info)
             task_info=task_info)
 
 
-    def cancel_task(self, ctxt, server, task_id, process_id, force):
-        # Needs to be executed on the same server
-        cctxt = self._client.prepare(server=server)
-        cctxt.call(ctxt, 'cancel_task', task_id=task_id, process_id=process_id,
-                   force=force)
+    def cancel_task(self, ctxt, task_id, process_id, force):
+        return self._call(
+            ctxt, 'cancel_task',
+            task_id=task_id, process_id=process_id, force=force)
 
 
     def get_endpoint_instances(self, ctxt, platform_name, connection_info,
     def get_endpoint_instances(self, ctxt, platform_name, connection_info,
                                source_environment, marker=None, limit=None,
                                source_environment, marker=None, limit=None,
                                instance_name_pattern=None):
                                instance_name_pattern=None):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_instances',
             ctxt, 'get_endpoint_instances',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
@@ -55,7 +60,7 @@ class WorkerClient(object):
 
 
     def get_endpoint_instance(self, ctxt, platform_name, connection_info,
     def get_endpoint_instance(self, ctxt, platform_name, connection_info,
                               source_environment, instance_name):
                               source_environment, instance_name):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_instance',
             ctxt, 'get_endpoint_instance',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
@@ -64,7 +69,7 @@ class WorkerClient(object):
 
 
     def get_endpoint_destination_options(
     def get_endpoint_destination_options(
             self, ctxt, platform_name, connection_info, env, option_names):
             self, ctxt, platform_name, connection_info, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_destination_options',
             ctxt, 'get_endpoint_destination_options',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
@@ -73,7 +78,7 @@ class WorkerClient(object):
 
 
     def get_endpoint_source_options(
     def get_endpoint_source_options(
             self, ctxt, platform_name, connection_info, env, option_names):
             self, ctxt, platform_name, connection_info, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_source_options',
             ctxt, 'get_endpoint_source_options',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
@@ -81,7 +86,7 @@ class WorkerClient(object):
             option_names=option_names)
             option_names=option_names)
 
 
     def get_endpoint_networks(self, ctxt, platform_name, connection_info, env):
     def get_endpoint_networks(self, ctxt, platform_name, connection_info, env):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_networks',
             ctxt, 'get_endpoint_networks',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
@@ -89,70 +94,70 @@ class WorkerClient(object):
 
 
     def validate_endpoint_connection(self, ctxt, platform_name,
     def validate_endpoint_connection(self, ctxt, platform_name,
                                      connection_info):
                                      connection_info):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_connection',
             ctxt, 'validate_endpoint_connection',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info)
             connection_info=connection_info)
 
 
     def validate_endpoint_target_environment(
     def validate_endpoint_target_environment(
             self, ctxt, platform_name, target_env):
             self, ctxt, platform_name, target_env):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_target_environment',
             ctxt, 'validate_endpoint_target_environment',
             platform_name=platform_name,
             platform_name=platform_name,
             target_env=target_env)
             target_env=target_env)
 
 
     def validate_endpoint_source_environment(
     def validate_endpoint_source_environment(
             self, ctxt, platform_name, source_env):
             self, ctxt, platform_name, source_env):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_source_environment',
             ctxt, 'validate_endpoint_source_environment',
             platform_name=platform_name,
             platform_name=platform_name,
             source_env=source_env)
             source_env=source_env)
 
 
     def get_endpoint_storage(self, ctxt, platform_name, connection_info, env):
     def get_endpoint_storage(self, ctxt, platform_name, connection_info, env):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_storage',
             ctxt, 'get_endpoint_storage',
             platform_name=platform_name,
             platform_name=platform_name,
             connection_info=connection_info,
             connection_info=connection_info,
             env=env)
             env=env)
 
 
     def get_available_providers(self, ctxt):
     def get_available_providers(self, ctxt):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_available_providers')
             ctxt, 'get_available_providers')
 
 
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
     def get_provider_schemas(self, ctxt, platform_name, provider_type):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_provider_schemas',
             ctxt, 'get_provider_schemas',
             platform_name=platform_name,
             platform_name=platform_name,
             provider_type=provider_type)
             provider_type=provider_type)
 
 
     def get_diagnostics(self, ctxt):
     def get_diagnostics(self, ctxt):
-        return self._client.call(ctxt, 'get_diagnostics')
+        return self._call(ctxt, 'get_diagnostics')
 
 
     def get_service_status(self, ctxt):
     def get_service_status(self, ctxt):
-        return self._client.call(ctxt, 'get_service_status')
+        return self._call(ctxt, 'get_service_status')
 
 
     def get_endpoint_source_minion_pool_options(
     def get_endpoint_source_minion_pool_options(
             self, ctxt, platform_name, connection_info, env, option_names):
             self, ctxt, platform_name, connection_info, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_source_minion_pool_options',
             ctxt, 'get_endpoint_source_minion_pool_options',
             platform_name=platform_name, connection_info=connection_info,
             platform_name=platform_name, connection_info=connection_info,
             env=env, option_names=option_names)
             env=env, option_names=option_names)
 
 
     def get_endpoint_destination_minion_pool_options(
     def get_endpoint_destination_minion_pool_options(
             self, ctxt, platform_name, connection_info, env, option_names):
             self, ctxt, platform_name, connection_info, env, option_names):
-        return self._client.call(
+        return self._call(
             ctxt, 'get_endpoint_destination_minion_pool_options',
             ctxt, 'get_endpoint_destination_minion_pool_options',
             platform_name=platform_name, connection_info=connection_info,
             platform_name=platform_name, connection_info=connection_info,
             env=env, option_names=option_names)
             env=env, option_names=option_names)
 
 
     def validate_endpoint_source_minion_pool_options(
     def validate_endpoint_source_minion_pool_options(
             self, ctxt, platform_name, pool_environment):
             self, ctxt, platform_name, pool_environment):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_source_minion_pool_options',
             ctxt, 'validate_endpoint_source_minion_pool_options',
             platform_name=platform_name, pool_environment=pool_environment)
             platform_name=platform_name, pool_environment=pool_environment)
 
 
     def validate_endpoint_destination_minion_pool_options(
     def validate_endpoint_destination_minion_pool_options(
             self, ctxt, platform_name, pool_environment):
             self, ctxt, platform_name, pool_environment):
-        return self._client.call(
+        return self._call(
             ctxt, 'validate_endpoint_destination_minion_pool_options',
             ctxt, 'validate_endpoint_destination_minion_pool_options',
             platform_name=platform_name, pool_environment=pool_environment)
             platform_name=platform_name, pool_environment=pool_environment)