Просмотр исходного кода

cleanup: Refactor SSH client creation

SSH clients are being created in multiple places. We should add an util
function for it.
Claudiu Belu 2 недель назад
Родитель
Сommit
e6d2a40ab5

+ 2 - 5
coriolis/osmorphing/osmount/base.py

@@ -11,7 +11,6 @@ import re
 import uuid
 
 from oslo_log import log as logging
-import paramiko
 from six import with_metaclass
 
 from coriolis import exception
@@ -82,10 +81,8 @@ class BaseSSHOSMountTools(BaseOSMountTools):
         self._event_manager.progress_update(
             "Connecting through SSH to OSMorphing host on: %(ip)s:%(port)s" %
             ({"ip": ip, "port": port}))
-        ssh = paramiko.SSHClient()
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        ssh.connect(hostname=ip, port=port, username=username, pkey=pkey,
-                    password=password)
+        ssh = utils.connect_ssh(
+            ip, port, username, pkey=pkey, password=password)
         ssh.set_log_channel("paramiko.morpher.%s.%s" % (ip, port))
         self._ssh = ssh
 

+ 6 - 28
coriolis/providers/backup_writers.py

@@ -564,20 +564,9 @@ class SSHBackupWriter(BaseBackupWriter):
     def _connect_ssh(self):
         LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
                  {"ip": self._ip, "port": self._port})
-        ssh = paramiko.SSHClient()
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        try:
-            ssh.connect(
-                hostname=self._ip,
-                port=self._port,
-                username=self._username,
-                pkey=self._pkey,
-                password=self._password)
-        except (Exception, KeyboardInterrupt):
-            # No need to log the error as we just raise
-            ssh.close()
-            raise
-        return ssh
+        return utils.connect_ssh(
+            self._ip, self._port, self._username,
+            pkey=self._pkey, password=self._password)
 
 
 class HTTPBackupWriterImpl(BaseBackupWriterImpl):
@@ -957,20 +946,9 @@ class HTTPBackupWriterBootstrapper(object):
     def _connect_ssh(self):
         LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
                  {"ip": self._ip, "port": self._port})
-        ssh = paramiko.SSHClient()
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        try:
-            ssh.connect(
-                hostname=self._ip,
-                port=self._port,
-                username=self._username,
-                pkey=self._pkey,
-                password=self._password)
-        except (Exception, KeyboardInterrupt):
-            # No need to log the error as we just raise
-            ssh.close()
-            raise
-        return ssh
+        return utils.connect_ssh(
+            self._ip, self._port, self._username,
+            pkey=self._pkey, password=self._password)
 
     def _inject_dport_allow_rule(self, ssh):
         cmd = (

+ 4 - 12
coriolis/providers/replicator.py

@@ -496,18 +496,10 @@ class Replicator(object):
         """
         gets a paramiko SSH client
         """
-        try:
-            ssh = paramiko.SSHClient()
-            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-            try:
-                ssh.connect(**args)
-                return ssh
-            except Exception:
-                ssh.close()
-                raise
-        except paramiko.ssh_exception.SSHException as ex:
-            raise exception.CoriolisException(
-                "Failed to setup SSH client: %s" % str(ex)) from ex
+        return utils.connect_ssh(
+            args["hostname"], args["port"], args["username"],
+            pkey=args.get("pkey"), password=args.get("password"),
+            banner_timeout=args.get("banner_timeout"))
 
     def _parse_source_ssh_conn_info(self, conn_info):
         # if we get valid SSH connection info we can

+ 1 - 4
coriolis/tests/integration/providers/test_provider/imp.py

@@ -259,10 +259,7 @@ class TestImportProvider(
 # Helpers
 def _ssh_connect(pkey_path):
     pkey = paramiko.RSAKey.from_private_key_file(pkey_path)
-    ssh = paramiko.SSHClient()
-    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-    ssh.connect(hostname="127.0.0.1", username="root", pkey=pkey)
-    return ssh
+    return utils.connect_ssh("127.0.0.1", 22, "root", pkey=pkey)
 
 
 def _read_file(path):

+ 21 - 3
coriolis/tests/providers/test_replicator.py

@@ -673,12 +673,21 @@ class ReplicatorTestCase(test_base.CoriolisBaseTestCase):
         original_get_ssh_client = testutils.get_wrapped_function(
             self.replicator._get_ssh_client)
 
-        result = original_get_ssh_client(self.replicator, self.conn_info)
+        arg = {
+            "hostname": self.conn_info["ip"],
+            "port": self.conn_info["port"],
+            "username": self.conn_info["username"],
+            "password": self.conn_info["password"],
+            "pkey": None,
+            "banner_timeout": (
+                replicator_module.CONF.replicator.default_requests_timeout),
+        }
+        result = original_get_ssh_client(self.replicator, arg)
 
         mock_ssh_client.assert_called_once()
         self._ssh.set_missing_host_key_policy.assert_called_once_with(
             mock.ANY)
-        self._ssh.connect.assert_called_once_with(**self.conn_info)
+        self._ssh.connect.assert_called_once_with(**arg)
 
         self.assertEqual(result, mock_ssh_client.return_value)
 
@@ -691,8 +700,17 @@ class ReplicatorTestCase(test_base.CoriolisBaseTestCase):
         original_get_ssh_client = testutils.get_wrapped_function(
             self.replicator._get_ssh_client)
 
+        arg = {
+            "hostname": self.conn_info["ip"],
+            "port": self.conn_info["port"],
+            "username": self.conn_info["username"],
+            "password": self.conn_info["password"],
+            "pkey": None,
+            "banner_timeout": (
+                replicator_module.CONF.replicator.default_requests_timeout),
+        }
         self.assertRaises(exception.CoriolisException, original_get_ssh_client,
-                          self.replicator, self.conn_info)
+                          self.replicator, arg)
 
     def test__parse_source_ssh_conn_info(self):
         expected_arg = {

+ 33 - 0
coriolis/utils.py

@@ -523,6 +523,39 @@ def deserialize_key(key_bytes, password=None):
     return paramiko.RSAKey.from_private_key(key_io, password)
 
 
+def connect_ssh(hostname, port, username, pkey=None, password=None,
+                connect_timeout=None, banner_timeout=None):
+    """Open and return a connected paramiko SSHClient.
+
+    :param pkey: a paramiko.PKey instance or None.
+    :param password: plaintext password or None.
+    :param connect_timeout: socket-level timeout in seconds (None = default).
+    :param banner_timeout: banner timeout in seconds passed to paramiko.
+    :raises: exception.CoriolisException on failure.
+    """
+    ssh = paramiko.SSHClient()
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    kwargs = dict(
+        hostname=hostname, port=port, username=username,
+        pkey=pkey, password=password)
+
+    if connect_timeout is not None:
+        kwargs["timeout"] = connect_timeout
+    if banner_timeout is not None:
+        kwargs["banner_timeout"] = banner_timeout
+
+    try:
+        ssh.connect(**kwargs)
+    except paramiko.ssh_exception.SSHException as ex:
+        raise exception.CoriolisException(
+            "Failed to setup SSH client: %s" % str(ex)) from ex
+    except (Exception, KeyboardInterrupt):
+        ssh.close()
+        raise
+
+    return ssh
+
+
 def is_serializable(obj):
     pickle.dumps(obj)