ソースを参照

Complete worker RPC server tests

Daniel Vincze 2 年 前
コミット
1209a5694e

+ 38 - 0
coriolis/tests/worker/rpc/data/get_custom_ld_path.yml

@@ -0,0 +1,38 @@
+- config:
+    extra_library_paths: ""
+    exception_expected: true
+
+- config:
+    extra_library_paths: ~
+    exception_expected: true
+
+- config:
+    extra_library_paths: 12
+    exception_expected: true
+
+- config:
+    extra_library_paths: {}
+    exception_expected: true
+
+- config:
+    extra_library_paths:
+      - 1
+      - 2
+    exception_expected: true
+
+- config:
+    extra_library_paths:
+      - "path1"
+      - "path2"
+    expected_result: "path1:path2"
+
+- config:
+    original_ld_path: "original_path"
+    expected_result: "original_path:"
+
+- config:
+    original_ld_path: "original_path"
+    extra_library_paths:
+      - "path1"
+      - "path2"
+    expected_result: "original_path:path1:path2"

+ 301 - 1
coriolis/tests/worker/rpc/test_server.py

@@ -1,19 +1,27 @@
 # Copyright 2022 Cloudbase Solutions Srl
 # All Rights Reserved.
-
 import multiprocessing
 import os
+import shutil
 import signal
+import tempfile
 from unittest import mock
 
 import ddt
 import eventlet
+from oslo_log import log as logging
 import psutil
+from six.moves import queue
 
+from coriolis.conductor.rpc import client as conductor_client
+from coriolis.conductor.rpc import utils as cond_rpc_utils
 from coriolis import constants
+from coriolis import context
 from coriolis import exception
+from coriolis.minion_manager.rpc import client as minion_client
 from coriolis.providers import factory as providers_factory
 from coriolis import schemas
+from coriolis.tasks import factory as task_runners_factory
 from coriolis.tests import test_base
 from coriolis import utils
 from coriolis.worker.rpc import server
@@ -29,6 +37,96 @@ class WorkerServerEndpointTestCase(test_base.CoriolisBaseTestCase):
         super(WorkerServerEndpointTestCase, self).setUp()
         self.server = server.WorkerServerEndpoint()
 
+    @mock.patch.object(minion_client, 'MinionManagerPoolRpcEventHandler')
+    def test__get_event_handler_for_task_type_minion(
+            self, mock_minion_event_handler):
+        result = server._get_event_handler_for_task_type(
+            constants.TASK_TYPE_VALIDATE_SOURCE_MINION_POOL_OPTIONS,
+            mock.sentinel.ctxt,
+            mock.sentinel.task_object_id)
+        mock_minion_event_handler.assert_called_once_with(
+            mock.sentinel.ctxt, mock.sentinel.task_object_id)
+        self.assertEqual(result, mock_minion_event_handler.return_value)
+
+    @mock.patch.object(conductor_client, 'ConductorTaskRpcEventHandler')
+    def test__get_event_handler_for_task_type(
+            self, mock_conductor_event_handler):
+        result = server._get_event_handler_for_task_type(
+            constants.TASK_TYPE_REPLICATE_DISKS,
+            mock.sentinel.ctxt,
+            mock.sentinel.task_object_id)
+        mock_conductor_event_handler.assert_called_once_with(
+            mock.sentinel.ctxt, mock.sentinel.task_object_id)
+        self.assertEqual(result, mock_conductor_event_handler.return_value)
+
+    @mock.patch.object(conductor_client, 'ConductorClient')
+    def test__rpc_conductor_client(self, mock_cond_client):
+        result = self.server._rpc_conductor_client
+        mock_cond_client.assert_called_once()
+        self.assertEqual(result, mock_cond_client.return_value)
+
+    @mock.patch.object(conductor_client, 'ConductorClient')
+    def test__rpc_conductor_client_instantiated(self, mock_cond_client):
+        self.server._rpc_conductor_client_instance = mock.sentinel.cond_client
+        mock_cond_client.assert_not_called()
+
+    @mock.patch.object(conductor_client, 'ConductorClient')
+    @mock.patch.object(cond_rpc_utils, 'check_create_registration_for_service')
+    @mock.patch.object(server.WorkerServerEndpoint, 'get_service_status')
+    @mock.patch.object(context, 'RequestContext')
+    @mock.patch.object(utils, 'get_binary_name')
+    @mock.patch.object(utils, 'get_hostname')
+    def test__register_worker_service(
+            self, mock_hostname, mock_binary, mock_context,
+            mock_get_service_status, mock_check_create_service,
+            mock_cond_client):
+        result = self.server._register_worker_service()
+
+        mock_hostname.assert_called_once()
+        mock_binary.assert_called_once()
+        mock_context.assert_called_once_with('coriolis', 'admin')
+        mock_get_service_status.assert_called_once_with(
+            mock_context.return_value)
+        mock_check_create_service.assert_called_once_with(
+            mock_cond_client.return_value, mock_context.return_value,
+            mock_hostname.return_value, mock_binary.return_value,
+            constants.WORKER_MAIN_MESSAGING_TOPIC, enabled=True,
+            providers=mock_get_service_status.return_value['providers'],
+            specs=mock_get_service_status.return_value['specs'])
+
+        self.assertEqual(result, mock_check_create_service.return_value)
+        self.assertEqual(result, self.server._service_registration)
+
+    def test__check_remove_dir(self):
+        tmp = tempfile.mkdtemp()
+        self.server._check_remove_dir(tmp)
+        self.assertFalse(os.path.exists(tmp))
+
+    @mock.patch.object(shutil, 'rmtree')
+    def test__check_remove_dir_fails(self, mock_rmtree):
+        tmp = tempfile.mkdtemp()
+        mock_rmtree.side_effect = Exception('YOLO')
+        self.server._check_remove_dir(tmp)
+        self.assertLogs(server.LOG, level=logging.ERROR)
+        os.rmdir(tmp)
+
+    @mock.patch.object(server.WorkerServerEndpoint, 'get_available_providers')
+    @mock.patch.object(server.WorkerServerEndpoint, 'get_diagnostics')
+    def test_get_service_status(self, mock_get_diagnostics,
+                                mock_get_available_providers):
+        expected_result = {
+            "host": mock_get_diagnostics.return_value['hostname'],
+            "binary": mock_get_diagnostics.return_value['application'],
+            "topic": constants.WORKER_MAIN_MESSAGING_TOPIC,
+            "providers": mock_get_available_providers.return_value,
+            "specs": mock_get_diagnostics.return_value,
+        }
+        result = self.server.get_service_status(mock.sentinel.ctxt)
+        mock_get_available_providers.assert_called_once_with(
+            mock.sentinel.ctxt)
+        mock_get_diagnostics.assert_called_once()
+        self.assertEqual(result, expected_result)
+
     @mock.patch.object(server.WorkerServerEndpoint,
                        "_start_process_with_custom_library_paths")
     @mock.patch.object(server, "_task_process")
@@ -194,6 +292,137 @@ class WorkerServerEndpointTestCase(test_base.CoriolisBaseTestCase):
             )
             mock_client.confirm_task_cancellation.assert_called_once()
 
+    @mock.patch.object(logging, 'getLogger')
+    def test__handle_mp_log_events(self, mock_get_logger):
+        mock_mp_log_q = mock.MagicMock()
+        mock_p = mock.MagicMock()
+        mock_mp_log_q.get.side_effect = [
+            mock.sentinel.record, queue.Empty, None]
+        mock_p.is_alive.return_value = True
+
+        result = self.server._handle_mp_log_events(mock_p, mock_mp_log_q)
+        mock_get_logger.assert_called_once_with(mock.sentinel.record.name)
+        mock_get_logger.return_value.logger.handle.assert_called_once_with(
+            mock.sentinel.record)
+        self.assertIsNone(result)
+
+    def test__handle_mp_log_events_dead_process(self):
+        mock_mp_log_q = mock.MagicMock()
+        mock_p = mock.MagicMock()
+        mock_mp_log_q.get.side_effect = queue.Empty
+        mock_p.is_alive.return_value = False
+
+        result = self.server._handle_mp_log_events(mock_p, mock_mp_log_q)
+        self.assertIsNone(result)
+
+    @ddt.file_data('data/get_custom_ld_path.yml')
+    @ddt.unpack
+    def test__get_custom_ld_path(self, config):
+        original_ld_path = config.get("original_ld_path", "")
+        extra_library_paths = config.get("extra_library_paths", [])
+        exception_expected = config.get("exception_expected", False)
+        expected_result = config.get("expected_result")
+
+        if exception_expected:
+            self.assertRaises(
+                TypeError, self.server._get_custom_ld_path,
+                original_ld_path, extra_library_paths)
+            return
+
+        result = self.server._get_custom_ld_path(
+            original_ld_path, extra_library_paths)
+        self.assertEqual(result, expected_result)
+
+    @mock.patch.object(server.WorkerServerEndpoint, '_get_custom_ld_path')
+    def test__start_process_with_custom_library_paths(
+            self, mock_get_custom_ld_path):
+        original_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
+        # NOTE(dvincze): Return value is required to be string here, as this
+        # value will be assigned to environment variable LD_LIBRARY_PATH
+        mock_get_custom_ld_path.return_value = "custom_ld_path"
+        process = mock.MagicMock()
+
+        self.server._start_process_with_custom_library_paths(
+            process, mock.sentinel.extra_library_paths)
+        process.start.assert_called_once()
+        mock_get_custom_ld_path.assert_called_once_with(
+            original_ld_path, mock.sentinel.extra_library_paths)
+        self.assertEqual(original_ld_path, os.environ['LD_LIBRARY_PATH'])
+
+    @mock.patch.object(server.WorkerServerEndpoint, '_get_custom_ld_path')
+    def test__start_process_with_custom_library_paths_raises(
+            self, mock_custom_ld_path):
+        original_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
+        mock_custom_ld_path.side_effect = TypeError()
+        process = mock.MagicMock()
+
+        self.server._start_process_with_custom_library_paths(
+            process, mock.sentinel.extra_library_paths)
+        process.assert_not_called()
+        self.assertLogs(server.LOG, logging.WARNING)
+        self.assertEqual(original_ld_path, os.environ['LD_LIBRARY_PATH'])
+
+    @mock.patch.object(task_runners_factory, 'get_task_runner_class')
+    @mock.patch.object(server, '_get_event_handler_for_task_type')
+    def test__get_extra_library_paths_for_providers(
+            self, mock_get_event_handler, mock_get_task_runner):
+        result = self.server._get_extra_library_paths_for_providers(
+            mock.sentinel.ctxt, mock.sentinel.task_id, mock.sentinel.task_type,
+            mock.sentinel.origin, mock.sentinel.destination)
+        mock_get_event_handler.assert_called_once_with(mock.sentinel.task_type,
+                                                       mock.sentinel.ctxt,
+                                                       mock.sentinel.task_id)
+        mock_get_task_runner.assert_called_once_with(mock.sentinel.task_type)
+        mock_task_runner = (
+            mock_get_task_runner.return_value.return_value)
+        mock_task_runner.get_shared_libs_for_providers.assert_called_once_with(
+            mock.sentinel.ctxt, mock.sentinel.origin,
+            mock.sentinel.destination, mock_get_event_handler.return_value)
+        self.assertEqual(
+            result,
+            mock_task_runner.get_shared_libs_for_providers.return_value)
+
+    def test__wait_for_process(self):
+        p = mock.MagicMock()
+        mp_q = mock.MagicMock()
+
+        mp_q.get.side_effect = ["result", "result"]
+        p.is_alive.side_effect = [True, False]
+
+        result = self.server._wait_for_process(p, mp_q)
+
+        self.assertEqual(result, "result")
+
+    def test__wait_for_process_empty_queue_dead_process(self):
+        p = mock.MagicMock()
+        mp_q = mock.MagicMock()
+
+        mp_q.get.side_effect = [queue.Empty]
+        p.is_alive.side_effect = [False]
+
+        result = self.server._wait_for_process(p, mp_q)
+        self.assertEqual(result, None)
+
+    def test__wait_for_process_dead_process(self):
+        p = mock.MagicMock()
+        mp_q = mock.MagicMock()
+
+        mp_q.get.side_effect = [None, "result"]
+        p.is_alive.side_effect = [False]
+
+        result = self.server._wait_for_process(p, mp_q)
+        self.assertEqual(result, "result")
+
+    def test__wait_for_process_dead_process_quere_raise(self):
+        p = mock.MagicMock()
+        mp_q = mock.MagicMock()
+
+        mp_q.get.side_effect = [None, Exception]
+        p.is_alive.side_effect = [False]
+
+        result = self.server._wait_for_process(p, mp_q)
+        self.assertEqual(result, None)
+
     @mock.patch.object(server.WorkerServerEndpoint, "_exec_task_process")
     @mock.patch.object(server.WorkerServerEndpoint, "_rpc_conductor_client")
     @mock.patch.object(utils, "sanitize_task_info")
@@ -941,3 +1170,74 @@ class WorkerServerEndpointTestCase(test_base.CoriolisBaseTestCase):
     def test_get_diagnostics(self, mock_get_diagnostics_info):
         result = self.server.get_diagnostics(mock.sentinel.context)
         self.assertEqual(result, mock_get_diagnostics_info.return_value)
+
+    @mock.patch('coriolis.service.get_worker_count_from_args')
+    @mock.patch('sys.argv')
+    @mock.patch('oslo_config.cfg.CONF')
+    @mock.patch('coriolis.utils.setup_logging')
+    @mock.patch('oslo_log.log.getLogger')
+    @mock.patch('logging.handlers.QueueHandler')
+    def test__setup_task_process(self, mock_queue_handler, mock_get_logger,
+                                 mock_setup_logging, mock_conf, mock_argv,
+                                 mock_get_worker_count_from_args):
+        mock_get_worker_count_from_args.return_value = (None, "args")
+        mock_logger = mock_get_logger.return_value.logger
+        mock_logger.handlers = [mock.sentinel.handler]
+        server._setup_task_process(mock.sentinel.mp_log_q)
+
+        mock_get_worker_count_from_args.assert_called_once_with(mock_argv)
+        mock_conf.assert_called_once_with(
+            mock_get_worker_count_from_args.return_value[1][1:],
+            project='coriolis', version='1.0.0')
+        mock_setup_logging.assert_called_once_with()
+        mock_get_logger.assert_called_once_with(None)
+        mock_logger.removeHandler.assert_called_once_with(
+            mock.sentinel.handler)
+        mock_queue_handler.assert_called_once_with(
+            mock.sentinel.mp_log_q)
+        mock_logger.addHandler.assert_called_once_with(
+            mock_queue_handler.return_value)
+
+    @mock.patch.object(server, '_setup_task_process')
+    @mock.patch.object(task_runners_factory, 'get_task_runner_class')
+    @mock.patch.object(server, '_get_event_handler_for_task_type')
+    @mock.patch('coriolis.utils.is_serializable')
+    def test__task_process(self, mock_is_serializable,
+                           mock_get_event_handler, mock_get_task_runner_class,
+                           mock_setup_task_process):
+        mp_q = mock.MagicMock()
+        mp_log_q = mock.MagicMock()
+        task_info = {}
+        mock_task_runner = mock_get_task_runner_class.return_value.return_value
+        mock_task_result = mock_task_runner.run.return_value
+
+        server._task_process(mock.sentinel.ctxt, mock.sentinel.task_id,
+                             mock.sentinel.task_type, mock.sentinel.origin,
+                             mock.sentinel.destination, mock.sentinel.instance,
+                             task_info, mp_q, mp_log_q)
+        mock_setup_task_process.assert_called_once_with(mp_log_q)
+        mock_get_task_runner_class.assert_called_once_with(
+            mock.sentinel.task_type)
+        mock_get_event_handler.assert_called_once_with(mock.sentinel.task_type,
+                                                       mock.sentinel.ctxt,
+                                                       mock.sentinel.task_id)
+        mock_task_runner.run.assert_called_once_with(
+            mock.sentinel.ctxt, mock.sentinel.instance, mock.sentinel.origin,
+            mock.sentinel.destination, task_info,
+            mock_get_event_handler.return_value)
+        mock_is_serializable.assert_called_once_with(mock_task_result)
+        mp_q.put.assert_called_once_with(mock_task_result)
+        mp_log_q.put.assert_called_once_with(None)
+
+    @mock.patch.object(server, '_setup_task_process')
+    def test__task_process_raise(self, mock_setup_task_process):
+        mock_setup_task_process.side_effect = Exception('YOLO')
+        mp_q = mock.MagicMock()
+        mp_log_q = mock.MagicMock()
+
+        server._task_process(mock.sentinel.ctxt, mock.sentinel.task_id,
+                             mock.sentinel.task_type, mock.sentinel.origin,
+                             mock.sentinel.destination, mock.sentinel.instance,
+                             mock.sentinel.task_info, mp_q, mp_log_q)
+        mp_q.put.assert_called_once_with("YOLO")
+        mp_log_q.put.assert_called_once_with(None)

+ 18 - 12
coriolis/worker/rpc/server.py

@@ -1,15 +1,14 @@
 # Copyright 2016 Cloudbase Solutions Srl
 # All Rights Reserved.
 
+from logging import handlers
 import multiprocessing
-
-import eventlet
 import os
 import shutil
 import signal
 import sys
 
-from logging import handlers
+import eventlet
 from oslo_config import cfg
 from oslo_log import log as logging
 import psutil
@@ -27,7 +26,6 @@ from coriolis import service
 from coriolis.tasks import factory as task_runners_factory
 from coriolis import utils
 
-
 CONF = cfg.CONF
 CONF.register_opts([], 'worker')
 
@@ -146,6 +144,16 @@ class WorkerServerEndpoint(object):
                 if not p.is_alive():
                     break
 
+    def _get_custom_ld_path(self, original_ld_path, extra_library_paths):
+        if not isinstance(extra_library_paths, list):
+            raise TypeError("Passed extra_library_paths is not a list")
+
+        extra_libdirs = ":".join(extra_library_paths)
+        if not original_ld_path:
+            return extra_libdirs
+        else:
+            return "%s:%s" % (original_ld_path, extra_libdirs)
+
     def _start_process_with_custom_library_paths(
             self, process, extra_library_paths):
         """ Given a process instance, this method will add any shared libs
@@ -159,19 +167,17 @@ class WorkerServerEndpoint(object):
         libraries which should be available to the worker process.
         """
         original_ld_path = os.environ.get('LD_LIBRARY_PATH', "")
-        new_ld_path = None
-        extra_libdirs = ":".join(extra_library_paths)
-        if not original_ld_path:
-            new_ld_path = extra_libdirs
-        else:
-            new_ld_path = "%s:%s" % (original_ld_path, extra_libdirs)
-
         LOG.debug(
             "Starting new worker process with extra libraries: '%s'",
             extra_library_paths)
         try:
-            os.environ['LD_LIBRARY_PATH'] = new_ld_path
+            os.environ['LD_LIBRARY_PATH'] = self._get_custom_ld_path(
+                original_ld_path, extra_library_paths)
             process.start()
+        except TypeError:
+            LOG.warning(
+                "Failed to set extra library paths: %s. Error was: %s",
+                extra_library_paths, utils.get_exception_details())
         finally:
             os.environ['LD_LIBRARY_PATH'] = original_ld_path