Просмотр исходного кода

Add full coverage for the `scheduler/rpc/server.py`
module

Mihaela Balutoiu 2 лет назад
Родитель
Сommit
428a1df4ad
1 измененных файлов с 122 добавлено и 2 удалено
  1. 122 2
      coriolis/tests/scheduler/rpc/test_server.py

+ 122 - 2
coriolis/tests/scheduler/rpc/test_server.py

@@ -1,15 +1,19 @@
-# Copyright 2023 Cloudbase Solutions Srl
+# Copyright 2024 Cloudbase Solutions Srl
 # All Rights Reserved.
 # All Rights Reserved.
 
 
+import logging
 from unittest import mock
 from unittest import mock
 
 
 import ddt
 import ddt
 
 
+from coriolis import constants
+from coriolis.db import api as db_api
 from coriolis import exception
 from coriolis import exception
 from coriolis.scheduler.filters import trivial_filters
 from coriolis.scheduler.filters import trivial_filters
 from coriolis.scheduler.rpc import server
 from coriolis.scheduler.rpc import server
 from coriolis.tests import test_base
 from coriolis.tests import test_base
 from coriolis.tests import testutils
 from coriolis.tests import testutils
+from coriolis import utils
 
 
 
 
 @ddt.ddt
 @ddt.ddt
@@ -20,6 +24,122 @@ class SchedulerServerEndpointTestCase(test_base.CoriolisBaseTestCase):
         super(SchedulerServerEndpointTestCase, self).setUp()
         super(SchedulerServerEndpointTestCase, self).setUp()
         self.server = server.SchedulerServerEndpoint()
         self.server = server.SchedulerServerEndpoint()
 
 
+    @mock.patch.object(utils, "get_diagnostics_info")
+    def test_get_diagnostics(self, mock_get_diagnostics_info):
+        result = self.server.get_diagnostics(mock.sentinel.context)
+
+        mock_get_diagnostics_info.assert_called_once_with()
+        self.assertEqual(result, mock_get_diagnostics_info.return_value)
+
+    @mock.patch.object(trivial_filters, 'TopicFilter', autospec=True)
+    @mock.patch.object(db_api, 'get_services')
+    def test_get_all_worker_services(self, mock_get_services,
+                                     mock_topic_filter_cls):
+        mock_get_services.return_value = mock.sentinel.services
+
+        mock_topic_filter_cls.return_value.filter_services.return_value = \
+            mock.sentinel.filtered_services
+
+        result = self.server._get_all_worker_services(mock.sentinel.context)
+
+        mock_get_services.assert_called_once_with(mock.sentinel.context)
+        mock_topic_filter_cls.assert_called_once_with(
+            constants.WORKER_MAIN_MESSAGING_TOPIC)
+        mock_topic_filter_cls.return_value.filter_services.\
+            assert_called_once_with(mock.sentinel.services)
+
+        self.assertEqual(result, mock.sentinel.filtered_services)
+
+    @mock.patch.object(db_api, 'get_services')
+    def test_get_all_worker_services_no_services(self, mock_get_services):
+        mock_get_services.return_value = []
+
+        self.assertRaises(exception.NoWorkerServiceError,
+                          self.server._get_all_worker_services,
+                          mock.sentinel.context)
+
+        mock_get_services.assert_called_once_with(mock.sentinel.context)
+
+    def test_get_weighted_filtered_services_no_filters(self):
+        services = [mock.Mock(id=1), mock.Mock(id=2)]
+
+        with self.assertLogs('coriolis.scheduler.rpc.server',
+                             level=logging.WARN):
+            result = self.server._get_weighted_filtered_services(services,
+                                                                 None)
+        expected_result = [(services[0], 100), (services[1], 100)]
+        self.assertEqual(result, expected_result)
+
+    def test_get_weighted_filtered_services_with_filters_reject(self):
+        services = [mock.Mock(id=1), mock.Mock(id=2)]
+        filters = [mock.Mock(), mock.Mock()]
+        filters[0].rate_service.return_value = 50
+        filters[1].rate_service.return_value = 0
+
+        self.assertRaises(exception.NoSuitableWorkerServiceError,
+                          self.server._get_weighted_filtered_services,
+                          services, filters)
+
+    def test_get_weighted_filtered_services_with_filters_accept(self):
+        services = [mock.Mock(id=1), mock.Mock(id=2)]
+        filters = [mock.Mock(), mock.Mock()]
+        filters[0].rate_service.return_value = 50
+        filters[1].rate_service.return_value = 100
+
+        result = self.server._get_weighted_filtered_services(services,
+                                                             filters)
+        expected_result = [(services[0], 150), (services[1], 150)]
+        self.assertEqual(result, expected_result)
+
+    @mock.patch.object(db_api, 'get_regions')
+    def test__filter_regions_check_all_exist_false(self, mock_get_regions):
+        mock_get_regions.return_value = [
+            mock.Mock(id='region1', enabled=True),
+            mock.Mock(id='region2', enabled=True),
+        ]
+        region_ids = ['region1', 'region2']
+
+        result = self.server._filter_regions(None, region_ids,
+                                             check_all_exist=False)
+
+        self.assertEqual(result, mock_get_regions.return_value)
+
+    @mock.patch.object(db_api, 'get_regions')
+    def test__filter_regions_all_disabled(self, mock_get_regions):
+        mock_get_regions.return_value = [
+            mock.Mock(id='region1', enabled=False),
+            mock.Mock(id='region2', enabled=False),
+        ]
+        region_ids = ['region1', 'region2']
+
+        result = self.server._filter_regions(None, region_ids, enabled=False)
+
+        self.assertEqual(result, mock_get_regions.return_value)
+
+    @mock.patch.object(db_api, 'get_regions')
+    def test__filter_regions_some_enabled_some_disabled(self,
+                                                        mock_get_regions):
+        mock_get_regions.return_value = [
+            mock.Mock(id='region1', enabled=True),
+            mock.Mock(id='region2', enabled=False),
+        ]
+        region_ids = ['region1', 'region2']
+
+        result = self.server._filter_regions(None, region_ids)
+
+        self.assertEqual(result, [mock_get_regions.return_value[0]])
+
+    @mock.patch.object(db_api, 'get_regions')
+    def test__filter_regions_some_missing(self, mock_get_regions):
+        mock_get_regions.return_value = [
+            mock.Mock(id='region1', enabled=True),
+            mock.Mock(id='region2', enabled=True),
+        ]
+        region_ids = ['region1', 'region2', 'region3']
+
+        self.assertRaises(exception.RegionNotFound,
+                          self.server._filter_regions, None, region_ids)
+
     @mock.patch.object(trivial_filters, 'ProviderTypesFilter', autospec=True)
     @mock.patch.object(trivial_filters, 'ProviderTypesFilter', autospec=True)
     @mock.patch.object(trivial_filters, 'RegionsFilter', autospec=True)
     @mock.patch.object(trivial_filters, 'RegionsFilter', autospec=True)
     @mock.patch.object(trivial_filters, 'EnabledFilter', autospec=True)
     @mock.patch.object(trivial_filters, 'EnabledFilter', autospec=True)
@@ -52,7 +172,7 @@ class SchedulerServerEndpointTestCase(test_base.CoriolisBaseTestCase):
         provider_requirements = config.get("provider_requirements", None)
         provider_requirements = config.get("provider_requirements", None)
 
 
         # Convert the config dict to an object, skipping the providers
         # Convert the config dict to an object, skipping the providers
-        # providers is the only field used as dict in the code
+        # as it's the only field used as dict in the code
         config_obj = testutils.DictToObject(config, skip_attrs=["providers"])
         config_obj = testutils.DictToObject(config, skip_attrs=["providers"])
         mock_get_all_worker_services.return_value = (
         mock_get_all_worker_services.return_value = (
             config_obj.services_db or []
             config_obj.services_db or []