Просмотр исходного кода

Merge pull request #112 from gabriel-samfira/refactor-backup-writer-interfaces

Refactor backup writer interfaces
Nashwan Azhari 6 лет назад
Родитель
Сommit
3c23124088
3 измененных файлов с 264 добавлено и 117 удалено
  1. 249 98
      coriolis/providers/backup_writers.py
  2. 8 5
      coriolis/tasks/base.py
  3. 7 14
      coriolis/tasks/replica_tasks.py

+ 249 - 98
coriolis/providers/backup_writers.py

@@ -10,6 +10,7 @@ import tempfile
 import threading
 import time
 import uuid
+import shutil
 
 import eventlet
 from oslo_config import cfg
@@ -34,6 +35,15 @@ CONF.register_opts(opts)
 _CORIOLIS_HTTP_WRITER_CMD = "coriolis-writer"
 
 LOG = logging.getLogger(__name__)
+BACKUP_WRITER_SSH = "ssh_backup_writer"
+BACKUP_WRITER_HTTP = "http_backup_writer"
+BACKUP_WRITER_FILE = "file_backup_writer"
+
+BACKUP_WRITERS = [
+    BACKUP_WRITER_SSH,
+    BACKUP_WRITER_HTTP,
+    BACKUP_WRITER_FILE
+]
 
 _WRITER_ERR_MAP = {
     -1: "ERR_MORE_MSG",
@@ -83,6 +93,44 @@ def _disable_lvm2_lvmetad(ssh):
             ssh, "sudo vgchange -an", get_pty=True)
 
 
+class BackupWritersFactory(object):
+
+    def __init__(self, writer_connection_info, volumes_info):
+        self._validate_info(writer_connection_info)
+        self._type = writer_connection_info["backend"]
+        self._conn_info = writer_connection_info["connection_details"]
+        self._volumes_info = volumes_info
+
+    def get_writer(self):
+        if self._type == BACKUP_WRITER_SSH:
+            return SSHBackupWriter.from_connection_info(
+                self._conn_info, self._volumes_info)
+        elif self._type == BACKUP_WRITER_HTTP:
+            return HTTPBackupWriter.from_connection_info(
+                self._conn_info, self._volumes_info)
+        elif self._type == BACKUP_WRITER_FILE:
+            return FileBackupWriter.from_connection_info(
+                self._conn_info, self._volumes_info)
+        raise exception.CoriolisException(
+            "Invalid backup writer type: %s" % self._type)
+
+    def _validate_info(self, info):
+        if type(info) is not dict:
+            raise exception.CoriolisException(
+                "Invalid backup writer connection info.")
+        wrt_type = info.get("backend", None)
+        if wrt_type is None:
+            raise exception.CoriolisException(
+                "Missing backend name in connection info")
+        if wrt_type not in BACKUP_WRITERS:
+            raise exception.CoriolisException(
+                "Invalid backup writer type: %s" % wrt_type)
+        wrt_conn_info = info.get("connection_details")
+        if wrt_conn_info is None:
+            raise exception.CoriolisException(
+                "Missing credentials in connection info")
+
+
 class BaseBackupWriterImpl(with_metaclass(abc.ABCMeta)):
     def __init__(self, path, disk_id):
         self._path = path
@@ -132,6 +180,11 @@ class BaseBackupWriter(with_metaclass(abc.ABCMeta)):
             if impl:
                 impl.close()
 
+    @classmethod
+    @abc.abstractmethod
+    def from_connection_info(cls, info, volumes_info):
+        pass
+
 
 class FileBackupWriterImpl(BaseBackupWriterImpl):
     def __init__(self, path, disk_id):
@@ -161,6 +214,10 @@ class FileBackupWriter(BaseBackupWriter):
     def _get_impl(self, path, disk_id):
         return FileBackupWriterImpl(path, disk_id)
 
+    @classmethod
+    def from_connection_info(cls, info, volumes_info):
+        return cls()
+
 
 class SSHBackupWriterImpl(BaseBackupWriterImpl):
     def __init__(self, path, disk_id, compress_transfer=None,
@@ -356,6 +413,32 @@ class SSHBackupWriter(BaseBackupWriter):
         self._ssh = None
         self._lock = threading.Lock()
 
+    @classmethod
+    def from_connection_info(cls, info, volumes_info):
+        required = ["ip", "port", "username"]
+        ip = info.get("ip")
+        port = info.get("port")
+        username = info.get("username")
+        pkey = info.get("pkey")
+        password = info.get("password")
+
+        if not all([ip, port, username]):
+            raise exception.CoriolisException(
+                "Connection info is invalid for SSHBackupWriter. "
+                "The following fields are required: %s" % ", ".join(required))
+        if pkey is None and password is None:
+            raise exception.CoriolisException(
+                "Either pkey or password are required")
+
+        if pkey:
+            if type(pkey) is not str:
+                raise exception.CoriolisException(
+                    "pkey must be a PEM encoded RSA private key")
+            pkey = utils.deserialize_key(
+                pkey, CONF.serialization.temp_keypair_password)
+
+        return cls(ip, port, username, pkey, password, volumes_info)
+
     def _get_impl(self, path, disk_id):
         ssh = self._connect_ssh()
         _disable_lvm2_lvmetad(ssh)
@@ -626,38 +709,48 @@ class HTTPBackupWriterImpl(BaseBackupWriterImpl):
             self._compressor_evt = None
 
 
-class HTTPBackupWriter(BaseBackupWriter):
+class HTTPBackupWriterBoostrapper(object):
 
-    def __init__(self, ip, port, username, pkey,
-                 password, writer_port, volumes_info,
-                 cert_dir, compressor_count=3):
-        self._ip = ip
-        self._port = port
-        self._username = username
-        self._pkey = pkey
-        self._password = password
-        self._volumes_info = volumes_info
-        self._writer_port = writer_port
+    def __init__(self, ssh_conn_info, writer_port):
         self._lock = threading.Lock()
-        self._id = str(uuid.uuid4())
-        self._compressor_count = compressor_count
         self._writer_cmd = os.path.join(
             "/usr/bin", _CORIOLIS_HTTP_WRITER_CMD)
-        self._crt = None
-        self._key = None
-        self._ca = None
-        if os.path.isdir(cert_dir) is False:
+        self._writer_port = writer_port
+        self._ip = ssh_conn_info.get("ip")
+        self._port = ssh_conn_info.get("port", 22)
+        self._username = ssh_conn_info.get("username")
+        self._password = ssh_conn_info.get("password")
+        self._pkey = ssh_conn_info.get("pkey")
+        if not all([self._ip, self._port, self._username]):
             raise exception.CoriolisException(
-                "Certificates dir %s does not exist" % cert_dir
-            )
-        self._crt_dir = cert_dir
+                "Invalid SSH connection info. IP, port and"
+                " username are mandatory")
+        if self._password is None and self._pkey is None:
+            raise exception.CoriolisException(
+                "Either password or pkey are required")
+        if self._pkey:
+            self._pkey = utils.deserialize_key(
+                self._pkey, CONF.serialization.temp_keypair_password)
+        self._ssh = self._connect_ssh()
 
-    def _wait_for_conn(self):
-        LOG.debug(
-            "waiting for coriolis-writer connectivity %s:%s" % (
-                self._ip, self._writer_port))
-        utils.wait_for_port_connectivity(
-            self._ip, self._writer_port)
+    @utils.retry_on_error(sleep_seconds=30)
+    def _connect_ssh(self):
+        LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
+                 {"ip": self._ip, "port": self._port})
+        ssh = paramiko.SSHClient()
+        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        try:
+            ssh.connect(
+                hostname=self._ip,
+                port=self._port,
+                username=self._username,
+                pkey=self._pkey,
+                password=self._password)
+        except (Exception, KeyboardInterrupt):
+            # No need to log the error as we just raise
+            ssh.close()
+            raise
+        return ssh
 
     def _inject_iptables_allow(self, ssh):
         utils.exec_ssh_cmd(
@@ -665,25 +758,6 @@ class HTTPBackupWriter(BaseBackupWriter):
             "sudo /sbin/iptables -I INPUT -p tcp --dport %s "
             "-j ACCEPT" % self._writer_port, get_pty=True)
 
-    def _get_impl(self, path, disk_id):
-        ssh = self._connect_ssh()
-        _disable_lvm2_lvmetad(ssh)
-        self._setup_writer(ssh)
-
-        path = [v for v in self._volumes_info
-                if v["disk_id"] == disk_id][0]["volume_dev"]
-        impl = HTTPBackupWriterImpl(
-            path, disk_id, self._compressor_count)
-        impl._set_info({
-            "ip": self._ip,
-            "port": self._writer_port,
-            "client_crt": self._crt,
-            "client_key": self._key,
-            "ca_crt": self._ca,
-            "id": self._id,
-        })
-        return impl
-
     @utils.retry_on_error()
     def _copy_writer(self, ssh):
         local_path = os.path.join(
@@ -737,16 +811,11 @@ class HTTPBackupWriter(BaseBackupWriter):
         remote_srv_crt = os.path.join(remote_base_dir, srv_crt_name)
         remote_srv_key = os.path.join(remote_base_dir, srv_key_name)
 
-        ca_crt = os.path.join(self._crt_dir, ca_crt_name)
-        client_crt = os.path.join(self._crt_dir, client_crt_name)
-        client_key = os.path.join(self._crt_dir, client_key_name)
-
         exist = []
         for i in (remote_ca_crt, remote_client_crt, remote_client_key,
                   remote_srv_crt, remote_srv_key):
             exist.append(utils.test_ssh_path(ssh, i))
 
-        force_fetch = False
         if not all(exist):
             utils.exec_ssh_cmd(
                 ssh, "sudo mkdir -p %s" % remote_base_dir, get_pty=True)
@@ -759,32 +828,20 @@ class HTTPBackupWriter(BaseBackupWriter):
                     "extra_hosts": self._ip,
                 },
                 get_pty=True)
-            force_fetch = True
-
-        exists = []
-        for i in (ca_crt, client_crt, client_key):
-            exists.append(os.path.isfile(i))
-
-        if not all(exists) or force_fetch:
-            # certificates either are missing, or have been regenerated
-            # on the writer worker. We need to fetch them.
-            self._fetch_remote_file(ssh, remote_ca_crt, ca_crt)
-            self._fetch_remote_file(ssh, remote_client_crt, client_crt)
-            self._fetch_remote_file(ssh, remote_client_key, client_key)
 
         return {
-            "local": {
-                "client_crt": client_crt,
-                "client_key": client_key,
-                "ca_crt": ca_crt,
-            },
-            "remote": {
-                "srv_crt": remote_srv_crt,
-                "srv_key": remote_srv_key,
-                "ca_crt": remote_ca_crt,
-            },
+            "srv_crt": remote_srv_crt,
+            "srv_key": remote_srv_key,
+            "ca_crt": remote_ca_crt,
+            "client_crt": remote_client_crt,
+            "client_key": remote_client_key
         }
 
+    def _read_remote_file_sudo(self, remote_path):
+        contents = utils.exec_ssh_cmd(
+            self._ssh, "sudo cat %s" % remote_path, get_pty=True)
+        return contents.decode()
+
     def _init_writer(self, ssh, cert_paths):
         cmdline = ("%(cmd)s run -ca-cert %(ca_cert)s -key "
                    "%(srv_key)s -cert %(srv_cert)s -listen-port "
@@ -798,33 +855,127 @@ class HTTPBackupWriter(BaseBackupWriter):
         utils.create_service(
             ssh, cmdline, _CORIOLIS_HTTP_WRITER_CMD, start=True)
         self._inject_iptables_allow(ssh)
-        self._wait_for_conn()
 
-    def _setup_writer(self, ssh):
-        self._copy_writer(ssh)
+    def setup_writer(self):
+        _disable_lvm2_lvmetad(self._ssh)
+        self._copy_writer(self._ssh)
         paths = utils.retry_on_error()(
-            self._setup_certificates)(ssh)
-        self._crt = paths["local"]["client_crt"]
-        self._key = paths["local"]["client_key"]
-        self._ca = paths["local"]["ca_crt"]
+            self._setup_certificates)(self._ssh)
         utils.retry_on_error()(
-            self._init_writer)(ssh, paths["remote"])
+            self._init_writer)(self._ssh, paths)
+        return {
+            "ip": self._ip,
+            "port": self._writer_port,
+            "certificates": {
+                "client_crt": self._read_remote_file_sudo(paths["client_crt"]),
+                "client_key": self._read_remote_file_sudo(paths["client_key"]),
+                "ca_crt": self._read_remote_file_sudo(paths["ca_crt"])
+            }
+        }
 
-    @utils.retry_on_error(sleep_seconds=30)
-    def _connect_ssh(self):
-        LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
-                 {"ip": self._ip, "port": self._port})
-        ssh = paramiko.SSHClient()
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        try:
-            ssh.connect(
-                hostname=self._ip,
-                port=self._port,
-                username=self._username,
-                pkey=self._pkey,
-                password=self._password)
-        except (Exception, KeyboardInterrupt):
-            # No need to log the error as we just raise
-            ssh.close()
-            raise
-        return ssh
+
+class HTTPBackupWriter(BaseBackupWriter):
+
+    def __init__(self, ip, port, volumes_info, certificates,
+                 compressor_count=3):
+        self._ip = ip
+        self._port = port
+        self._volumes_info = volumes_info
+        self._writer_port = port
+        self._id = str(uuid.uuid4())
+        self._compressor_count = compressor_count
+
+        self._certificates = certificates
+        self._crt_dir = tempfile.mkdtemp()
+        if not self._certificates:
+            raise exception.CoriolisException(
+                "certificates is mandatory")
+        self._cert_paths = None
+
+    @classmethod
+    def from_connection_info(cls, conn_info, volumes_info):
+        """Instantiate a HTTP backup writer from connection info.
+
+        Connection info has the following schema:
+
+        {
+            # IP address or hostname where we can reach the backup writer
+            "ip": "192.168.0.1",
+            # Backup writer port
+            "port": 4433,
+            "certificates": {
+                # PEM encoded client certificate
+                "client_crt": "",
+                # PEM encoded client private key
+                "client_key": "",
+                # PEM encoded CA certificate we use to validate the server
+                "ca_crt": ""
+            }
+        }
+        """
+        ip = conn_info.get("ip")
+        port = conn_info.get("port")
+        certs = conn_info.get("certificates")
+
+        required = ["ip", "port", "certificates"]
+        if not all([ip, port, certs]):
+            raise exception.CoriolisException(
+                "Missing required connection info: %s" % ", ".join(required))
+        return cls(ip, port, volumes_info, certs)
+
+    def __del__(self):
+        if self._crt_dir and os.path.isdir(self._crt_dir):
+            try:
+                shutil.rmtree(self._crt_dir)
+            except BaseException:
+                pass
+
+    def _wait_for_conn(self):
+        LOG.debug(
+            "waiting for coriolis-writer connectivity %s:%s" % (
+                self._ip, self._writer_port))
+        utils.wait_for_port_connectivity(
+            self._ip, self._writer_port)
+
+    def _write_cert_files(self):
+        if not self._certificates:
+            raise exception.CoriolisException(
+                "certificates not set")
+        if self._cert_paths:
+            return self._cert_paths
+
+        crt_file = tempfile.mkstemp(dir=self._crt_dir)[1]
+        key_file = tempfile.mkstemp(dir=self._crt_dir)[1]
+        ca_crt_file = tempfile.mkstemp(dir=self._crt_dir)[1]
+        with open(crt_file, "w") as fd:
+            fd.write(self._certificates["client_crt"])
+        with open(key_file, "w") as fd:
+            fd.write(self._certificates["client_key"])
+        with open(ca_crt_file, "w") as fd:
+            fd.write(self._certificates["ca_crt"])
+        self._cert_paths = {
+            "client_crt": crt_file,
+            "client_key": key_file,
+            "ca_crt": ca_crt_file,
+        }
+        return self._cert_paths
+
+    def _get_impl(self, path, disk_id):
+        cert_paths = self._write_cert_files()
+        self._wait_for_conn()
+
+        path = [v for v in self._volumes_info
+                if v["disk_id"] == disk_id][0]["volume_dev"]
+        impl = HTTPBackupWriterImpl(
+            path, disk_id,
+            compressor_count=self._compressor_count,
+            compress_transfer=CONF.compress_transfers)
+        impl._set_info({
+            "ip": self._ip,
+            "port": self._writer_port,
+            "client_crt": cert_paths["client_crt"],
+            "client_key": cert_paths["client_key"],
+            "ca_crt": cert_paths["ca_crt"],
+            "id": self._id,
+        })
+        return impl

+ 8 - 5
coriolis/tasks/base.py

@@ -2,6 +2,7 @@
 # All Rights Reserved.
 
 import abc
+import paramiko
 
 from oslo_config import cfg
 from oslo_log import log as logging
@@ -64,9 +65,10 @@ def get_connection_info(ctxt, data):
 def marshal_migr_conn_info(migr_connection_info):
     if migr_connection_info and "pkey" in migr_connection_info:
         migr_connection_info = migr_connection_info.copy()
-        migr_connection_info["pkey"] = utils.serialize_key(
-            migr_connection_info["pkey"],
-            CONF.serialization.temp_keypair_password)
+        pkey = migr_connection_info["pkey"]
+        if isinstance(pkey, str) is False:
+            migr_connection_info["pkey"] = utils.serialize_key(
+                pkey, CONF.serialization.temp_keypair_password)
     return migr_connection_info
 
 
@@ -74,6 +76,7 @@ def unmarshal_migr_conn_info(migr_connection_info):
     if migr_connection_info and "pkey" in migr_connection_info:
         migr_connection_info = migr_connection_info.copy()
         pkey_str = migr_connection_info["pkey"]
-        migr_connection_info["pkey"] = utils.deserialize_key(
-            pkey_str, CONF.serialization.temp_keypair_password)
+        if isinstance(pkey_str, paramiko.rsakey.RSAKey) is False:
+            migr_connection_info["pkey"] = utils.deserialize_key(
+                pkey_str, CONF.serialization.temp_keypair_password)
     return migr_connection_info

+ 7 - 14
coriolis/tasks/replica_tasks.py

@@ -7,6 +7,7 @@ from coriolis import constants
 from coriolis import events
 from coriolis import exception
 from coriolis.providers import factory as providers_factory
+from coriolis.providers import backup_writers
 from coriolis import schemas
 from coriolis.tasks import base
 from coriolis import utils
@@ -120,12 +121,6 @@ class ReplicateDisksTask(base.TaskRunner):
             migr_source_conn_info)
 
         migr_target_conn_info = task_info["migr_target_connection_info"]
-        schemas.validate_value(
-            migr_target_conn_info,
-            schemas.CORIOLIS_DISK_SYNC_RESOURCES_CONN_INFO_SCHEMA)
-        migr_target_conn_info = base.unmarshal_migr_conn_info(
-            migr_target_conn_info)
-
         incremental = task_info.get("incremental", True)
 
         source_environment = origin.get('source_environment') or {}
@@ -284,14 +279,12 @@ class DeployReplicaTargetResourcesTask(base.TaskRunner):
             "migr_resources"]
 
         migr_connection_info = replica_resources_info["connection_info"]
-        migr_connection_info = base.marshal_migr_conn_info(
-            migr_connection_info)
-        schemas.validate_value(
-            migr_connection_info,
-            schemas.CORIOLIS_DISK_SYNC_RESOURCES_CONN_INFO_SCHEMA,
-            # NOTE: we avoid raising so that the cleanup task
-            # can [try] to deal with the temporary resources.
-            raise_on_error=False)
+        try:
+            backup_writers.BackupWritersFactory(
+                migr_connection_info, None).get_writer()
+        except BaseException as err:
+            LOG.exception(
+                "Invalid connection info: %s" % err)
 
         task_info["migr_target_connection_info"] = migr_connection_info