Просмотр исходного кода

Merged in alexpilotti/coriolis/parallel_writers_readers (pull request #78)

Parallel writers and readers COR-34
Alessandro Pilotti 8 лет назад
Родитель
Сommit
b59bb00963
2 измененных файлов с 150 добавлено и 89 удалено
  1. 132 77
      coriolis/providers/backup_writers.py
  2. 18 12
      coriolis/qemu_reader.py

+ 132 - 77
coriolis/providers/backup_writers.py

@@ -3,7 +3,9 @@
 
 import abc
 import contextlib
+import errno
 import os
+import threading
 
 from oslo_log import log as logging
 import paramiko
@@ -15,20 +17,17 @@ from coriolis import utils
 LOG = logging.getLogger(__name__)
 
 
-class BaseBackupWriter(metaclass=abc.ABCMeta):
+class BaseBackupWriterImpl(metaclass=abc.ABCMeta):
+    def __init__(self, path, disk_id):
+        self._path = path
+        self._disk_id = disk_id
+
     @abc.abstractmethod
     def _open(self):
         pass
 
-    @contextlib.contextmanager
-    def open(self, path, disk_id):
-        self._path = path
-        self._disk_id = disk_id
-        self._open()
-        try:
-            yield self
-        finally:
-            self.close()
+    def _handle_exception(self, ex):
+        LOG.exception(ex)
 
     @abc.abstractmethod
     def seek(self, pos):
@@ -47,7 +46,32 @@ class BaseBackupWriter(metaclass=abc.ABCMeta):
         pass
 
 
-class FileBackupWriter(BaseBackupWriter):
+class BaseBackupWriter(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def _get_impl(self, path, disk_id):
+        pass
+
+    @contextlib.contextmanager
+    def open(self, path, disk_id):
+        impl = None
+        try:
+            impl = self._get_impl(path, disk_id)
+            impl._open()
+            yield impl
+        except Exception as ex:
+            if impl:
+                impl._handle_exception(ex)
+            raise
+        finally:
+            if impl:
+                impl.close()
+
+
+class FileBackupWriterImpl(BaseBackupWriterImpl):
+    def __init__(self, path, disk_id):
+        self._file = None
+        super(FileBackupWriterImpl, self).__init__(path, disk_id)
+
     def _open(self):
         # Create file if it doesnt exist
         open(self._path, 'ab+').close()
@@ -64,69 +88,26 @@ class FileBackupWriter(BaseBackupWriter):
 
     def close(self):
         self._file.close()
+        self._file = None
 
 
-class SSHBackupWriter(BaseBackupWriter):
-    def __init__(self, ip, port, username, pkey, password, volumes_info):
-        self._ip = ip
-        self._port = port
-        self._username = username
-        self._pkey = pkey
-        self._password = password
-        self._volumes_info = volumes_info
-        self._ssh = None
-
-    @contextlib.contextmanager
-    def open(self, path, disk_id):
-        self._path = path
-        self._disk_id = disk_id
-        self._open()
-        try:
-            yield self
-            # Don't send a message via ssh on exception
-            self.close()
-        except Exception as ex:
-            LOG.exception(ex)
-
-            ret_val = None
-            # if the application is still running on the other side,
-            # recv_exit_status() will block. Check that we have an
-            # exit status before retrieving it
-            if self._stdout.channel.exit_status_ready():
-                ret_val = self._stdout.channel.recv_exit_status()
-
-            self._ssh.close()
+class FileBackupWriter(BaseBackupWriter):
+    def _get_impl(self, path, disk_id):
+        return FileBackupWriterImpl(path, disk_id)
 
-            if ret_val:
-                # TODO(alexpilotti): map error codes to error messages
-                raise exception.CoriolisException(
-                    "An exception occurred while writing data on target. "
-                    "Exit code: %s" % ret_val)
-            else:
-                raise exception.CoriolisException(
-                    "An exception occurred while writing data on target: %s" %
-                    ex)
 
-    @utils.retry_on_error()
-    def _connect_ssh(self):
-        LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
-                 {"ip": self._ip, "port": self._port})
-        self._ssh = paramiko.SSHClient()
-        self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        self._ssh.connect(
-            hostname=self._ip,
-            port=self._port,
-            username=self._username,
-            pkey=self._pkey,
-            password=self._password)
+class SSHBackupWriterImpl(BaseBackupWriterImpl):
+    def __init__(self, path, disk_id):
+        self._msg_id = None
+        self._stdin = None
+        self._stdout = None
+        self._stderr = None
+        self._offset = None
+        self._ssh = None
+        super(SSHBackupWriterImpl, self).__init__(path, disk_id)
 
-    @utils.retry_on_error()
-    def _copy_helper_cmd(self):
-        sftp = self._ssh.open_sftp()
-        local_path = os.path.join(
-            utils.get_resources_dir(), 'write_data')
-        sftp.put(local_path, 'write_data')
-        sftp.close()
+    def _set_ssh_client(self, ssh):
+        self._ssh = ssh
 
     @utils.retry_on_error()
     def _exec_helper_cmd(self):
@@ -136,16 +117,15 @@ class SSHBackupWriter(BaseBackupWriter):
             "chmod +x write_data && sudo ./write_data")
 
     def _encode_data(self, content):
-        path = [v for v in self._volumes_info
-                if v["disk_id"] == self._disk_id][0]["volume_dev"]
-
         msg = data_transfer.encode_data(
-            self._msg_id, path, self._offset, content)
+            self._msg_id, self._path, self._offset, content)
 
         LOG.debug(
             "Guest path: %(path)s, offset: %(offset)d, content len: "
             "%(content_len)d, msg len: %(msg_len)d",
-            {"path": path, "offset": self._offset, "content_len": len(content),
+            {"path": self._path,
+             "offset": self._offset,
+             "content_len": len(content),
              "msg_len": len(msg)})
         return msg
 
@@ -162,8 +142,6 @@ class SSHBackupWriter(BaseBackupWriter):
         self._stdout.read(4)
 
     def _open(self):
-        self._connect_ssh()
-        self._copy_helper_cmd()
         self._exec_helper_cmd()
 
     def seek(self, pos):
@@ -177,5 +155,82 @@ class SSHBackupWriter(BaseBackupWriter):
         self._offset += len(data)
 
     def close(self):
-        self._send_msg(self._encode_eod())
+        if self._ssh:
+            self._send_msg(self._encode_eod())
+            self._ssh.close()
+            self._ssh = None
+
+    def _handle_exception(self, ex):
+        ret_val = None
+        # if the application is still running on the other side,
+        # recv_exit_status() will block. Check that we have an
+        # exit status before retrieving it
+        if self._stdout.channel.exit_status_ready():
+            ret_val = self._stdout.channel.recv_exit_status()
+
+        # Don't send a message via ssh on exception
         self._ssh.close()
+        self._ssh = None
+
+        if ret_val:
+            # TODO(alexpilotti): map error codes to error messages
+            raise exception.CoriolisException(
+                "An exception occurred while writing data on target. "
+                "Exit code: %s" % ret_val)
+        else:
+            raise exception.CoriolisException(
+                "An exception occurred while writing data on target: %s" %
+                ex)
+
+
+class SSHBackupWriter(BaseBackupWriter):
+    def __init__(self, ip, port, username, pkey, password, volumes_info):
+        self._ip = ip
+        self._port = port
+        self._username = username
+        self._pkey = pkey
+        self._password = password
+        self._volumes_info = volumes_info
+        self._ssh = None
+        self._lock = threading.Lock()
+
+    def _get_impl(self, path, disk_id):
+        ssh = self._connect_ssh()
+
+        path = [v for v in self._volumes_info
+                if v["disk_id"] == disk_id][0]["volume_dev"]
+        impl = SSHBackupWriterImpl(path, disk_id)
+
+        self._copy_helper_cmd(ssh)
+        impl._set_ssh_client(ssh)
+        return impl
+
+    @utils.retry_on_error()
+    def _copy_helper_cmd(self, ssh):
+        with self._lock:
+            sftp = ssh.open_sftp()
+            local_path = os.path.join(
+                utils.get_resources_dir(), 'write_data')
+            try:
+                # Check if the remote file already exists
+                sftp.stat('write_data')
+            except IOError as ex:
+                if ex.errno != errno.ENOENT:
+                    raise
+                sftp.put(local_path, 'write_data')
+            finally:
+                sftp.close()
+
+    @utils.retry_on_error()
+    def _connect_ssh(self):
+        LOG.info("Connecting to SSH host: %(ip)s:%(port)s" %
+                 {"ip": self._ip, "port": self._port})
+        ssh = paramiko.SSHClient()
+        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        ssh.connect(
+            hostname=self._ip,
+            port=self._port,
+            username=self._username,
+            pkey=self._pkey,
+            password=self._password)
+        return ssh

+ 18 - 12
coriolis/qemu_reader.py

@@ -8,14 +8,15 @@ from coriolis import exception
 from coriolis import qemu
 
 
-class QEMUDiskImageReader(object):
-    def __init__(self):
+class QEMUDiskImageReaderImpl(object):
+    def __init__(self, path):
         self._blk = None
         self._bs = None
         self._total_sectors = None
         self._block_driver_state = None
         self._buf = None
         self._buf_size = None
+        self._path = path
 
     def close(self):
         if self._buf is not None:
@@ -31,12 +32,12 @@ class QEMUDiskImageReader(object):
         self._total_sectors = None
         self._block_driver_state = None
 
-    def _qemu_open_path(self, path):
+    def _open(self):
         error = ctypes.POINTER(qemu.Error)()
 
         options = qemu.qdict_new()
         blk = qemu.blk_new_open(
-            path.encode(), None, options, 0, ctypes.byref(error))
+            self._path.encode(), None, options, 0, ctypes.byref(error))
         if not blk:
             raise exception.QEMUException(error.msg)
 
@@ -49,14 +50,6 @@ class QEMUDiskImageReader(object):
     def disk_size(self):
         return self._total_sectors << qemu.BDRV_SECTOR_BITS
 
-    @contextlib.contextmanager
-    def open(self, path):
-        try:
-            self._qemu_open_path(path)
-            yield self
-        finally:
-            self.close()
-
     def _get_sectors(self, offset, size):
         start_sector = offset >> qemu.BDRV_SECTOR_BITS
         return (start_sector,
@@ -109,6 +102,19 @@ class QEMUDiskImageReader(object):
         return (ctypes.c_ubyte*read_size).from_address(self._buf)
 
 
+class QEMUDiskImageReader(object):
+    @contextlib.contextmanager
+    def open(self, path):
+        impl = None
+        try:
+            impl = QEMUDiskImageReaderImpl(path)
+            impl._open()
+            yield impl
+        finally:
+            if impl:
+                impl.close()
+
+
 def _qemu_init():
     error = ctypes.POINTER(qemu.Error)()