Преглед изворни кода

Allow get_pty

on some systems, running sudo over SSH requires you have a PTY
Gabriel Adrian Samfira пре 7 година
родитељ
комит
43f341acaa
3 измењених фајлова са 114 додато и 52 уклоњено
  1. 110 50
      coriolis/providers/replicator.py
  2. 1 0
      coriolis/secrets.py
  3. 3 2
      coriolis/utils.py

+ 110 - 50
coriolis/providers/replicator.py

@@ -39,8 +39,10 @@ CONF.register_opts(replicator_opts, 'replicator')
 class Client(object):
 
     def __init__(self, ip, port, credentials, ssh_conn_info,
-                 event_handler, use_compression=False):
+                 event_handler, use_compression=False,
+                 use_tunnel=False):
         self._ip = ip
+        self._use_tunnel = use_tunnel
         self._port = port
         self._event_manager = event_handler
         self._creds = credentials
@@ -88,23 +90,28 @@ class Client(object):
         Attempt to connect to the IP/port pair. If direct connection
         fails, set up a SSH tunnel and attempt a connection through that.
         """
-        self._event_manager.progress_update(
-            "Testing direct connection to replicator (%s:%s)" % (
-                self._ip, self._port))
-        try:
-            utils.wait_for_port_connectivity(
-                self._ip, self._port, max_wait=2)
-            return
-        except BaseException as err:
-            LOG.debug("failed to connect to %s:%s Error: %s "
-                      "Trying tunneled connection" % (
-                          self._ip, self._port, err))
-            self._event_manager.progress_update(
-                "Direct connection to replicator failed. Setting up tunnel.")
+        if self._use_tunnel:
+            # It was explicitly requested to use a tunnel
             self._setup_tunnel_connection()
+        else:
+            self._event_manager.progress_update(
+                "Testing direct connection to replicator (%s:%s)" % (
+                    self._ip, self._port))
+            try:
+                utils.wait_for_port_connectivity(
+                    self._ip, self._port, max_wait=2)
+                return
+            except BaseException as err:
+                LOG.debug("failed to connect to %s:%s Error: %s "
+                          "Trying tunneled connection" % (
+                              self._ip, self._port, err))
+                self._event_manager.progress_update(
+                    "Direct connection to replicator failed. "
+                    "Setting up tunnel.")
+                self._setup_tunnel_connection()
 
         self._event_manager.progress_update(
-            "Testing tunneled connection to replicator (%s:%s)" % (
+            "Testing connection to replicator (%s:%s)" % (
                 self.repl_host, self.repl_port))
         try:
             utils.wait_for_port_connectivity(
@@ -112,7 +119,7 @@ class Client(object):
         except BaseException as err:
             self._tunnel.stop()
             LOG.warning(
-                "failed to connect to replicator through tunnel: %s" % err)
+                "failed to connect to replicator: %s" % err)
             raise
 
     def _get_ssh_tunnel(self):
@@ -156,6 +163,7 @@ class Client(object):
         sess.verify = self._creds["ca_cert"]
         return sess
 
+    @utils.retry_on_error()
     def get_status(self, device=None, brief=True):
         uri = "%s/api/v1/dev/" % (self._base_uri)
         if device is not None:
@@ -167,6 +175,7 @@ class Client(object):
         status.raise_for_status()
         return status.json()
 
+    @utils.retry_on_error()
     def get_chunks(self, device, skip_zeros=False):
         uri = "%s/api/v1/dev/%s/chunks/" % (self._base_uri, device)
         params = {
@@ -176,12 +185,21 @@ class Client(object):
         chunks.raise_for_status()
         return chunks.json()
 
+    @utils.retry_on_error()
+    def get_changes(self, device):
+        uri = "%s/api/v1/dev/%s/chunks/changes/" % (self._base_uri, device)
+        chunks = self._cli.get(uri)
+        chunks.raise_for_status()
+        return chunks.json()
+
+    @utils.retry_on_error()
     def get_disk_size(self, disk):
         diskUri = self.raw_disk_uri(disk)
         info = self._cli.head(diskUri)
         info.raise_for_status()
         return int(info.headers["Content-Length"])
 
+    @utils.retry_on_error()
     def download_chunk(self, disk, chunk):
         diskUri = self.raw_disk_uri(disk)
 
@@ -205,7 +223,7 @@ class Replicator(object):
     def __init__(self, conn_info, event_manager, volumes_info, replica_state,
                  use_compression=False, ignore_mounted=True,
                  hash_method=HASH_METHOD_SHA256, watch_devices=True,
-                 chunk_size=10485760):
+                 chunk_size=10485760, use_tunnel=False):
         self._event_manager = event_manager
         self._repl_state = replica_state
         self._conn_info = conn_info
@@ -223,6 +241,7 @@ class Replicator(object):
         self._stdout = None
         self._stdin = None
         self._stderr = None
+        self._use_tunnel = use_tunnel
 
     def _init_replicator_client(self, credentials):
         """
@@ -237,7 +256,8 @@ class Replicator(object):
             args["ip"], args["port"],
             credentials, ssh_conn_info,
             self._event_manager,
-            use_compression=self._use_compression)
+            use_compression=self._use_compression,
+            use_tunnel=self._use_tunnel)
 
     def _setup_ssh(self):
         args = self._parse_source_ssh_conn_info(
@@ -353,7 +373,8 @@ class Replicator(object):
             if ex.errno != errno.ENOENT:
                 raise
             sftp.put(localPath, tmp)
-            utils.exec_ssh_cmd(ssh, "sudo mv %s %s" % (tmp, remotePath))
+            utils.exec_ssh_cmd(
+                ssh, "sudo mv %s %s" % (tmp, remotePath), get_pty=True)
         finally:
             sftp.close()
 
@@ -362,16 +383,29 @@ class Replicator(object):
         local_path = os.path.join(
             utils.get_resources_dir(), 'replicator')
         self._copy_file(ssh, local_path, REPLICATOR_PATH)
-        utils.exec_ssh_cmd(ssh, "sudo chmod +x %s" % REPLICATOR_PATH)
+        utils.exec_ssh_cmd(
+            ssh, "sudo chmod +x %s" % REPLICATOR_PATH, get_pty=True)
 
     def _setup_replicator_user(self, ssh):
         user_exists = utils.exec_ssh_cmd(
             ssh, "getent passwd replicator > /dev/null && echo 1 || echo 0")
         if int(user_exists) == 0:
             utils.exec_ssh_cmd(
-                ssh, "sudo useradd -m -s /bin/bash %s" % REPLICATOR_USERNAME)
+                ssh, "sudo useradd -m -s /bin/bash %s" % REPLICATOR_USERNAME,
+                get_pty=True)
             utils.exec_ssh_cmd(
-                ssh, "sudo usermod -aG disk %s" % REPLICATOR_USERNAME)
+                ssh, "sudo usermod -aG disk %s" % REPLICATOR_USERNAME,
+                get_pty=True)
+
+    def _check_replicator_errors(self):
+        if self._stdout.channel.exit_status_ready():
+            exit_code = self._stdout.channel.recv_exit_status()
+            if exit_code:
+                stderr = self._stderr.read()
+                stdout = self._stdout.read()
+                raise exception.CoriolisException(
+                    "failed to start replicator: stdout: "
+                    "%s; stderr: %s" % (stdout, stderr))
 
     @utils.retry_on_error()
     def _exec_replicator(self, ssh, args, state_file):
@@ -386,7 +420,7 @@ class Replicator(object):
             "sudo chown %(user)s:%(user)s %(config_dir)s" % {
                 "config_dir": self._config_dir,
                 "user": REPLICATOR_USERNAME,
-            })
+            }, get_pty=True)
         cmdline = ("/usr/bin/replicator -certificate-hosts=%(cert_hosts)s "
                    "-config-dir=%(cfgdir)s -hash-method=%(hash_method)s "
                    "-ignore-mounted-disks=%(ignore_mounted)s "
@@ -403,12 +437,11 @@ class Replicator(object):
                        "state_file": state_file,
                        "chunk_size": self._chunk_size,
                    })
-        self._event_manager.progress_update("running %s" % cmdline)
         self._stdin, self._stdout, self._stderr = ssh.exec_command(
             "sudo -u %(username)s -- %(cmdline)s > /tmp/replicator.log" % {
                 "cmdline": cmdline,
                 "username": REPLICATOR_USERNAME,
-            })
+            }, get_pty=True)
         count = 0
         # wait 5 seconds. If process exits, raise
         # TODO(gsamfira): create system service? That should take care of
@@ -416,14 +449,7 @@ class Replicator(object):
         while True:
             if count >= 5:
                 break
-            if self._stdout.channel.exit_status_ready():
-                exit_code = self._stdout.channel.recv_exit_status()
-                if exit_code:
-                    stderr = self._stderr.read()
-                    stdout = self._stdout.read()
-                    raise exception.CoriolisException(
-                        "failed to start replicator: stdout: "
-                        "%s; stderr: %s" % (stdout, stderr))
+            self._check_replicator_errors()
             time.sleep(1)
             count += 1
 
@@ -432,7 +458,9 @@ class Replicator(object):
         # copy the binary and execute it.
         state_file = self._get_replicator_state_file()
         self._copy_file(ssh, state_file, REPLICATOR_STATE)
-        utils.exec_ssh_cmd(ssh, "sudo chmod 755 %s" % REPLICATOR_STATE)
+        utils.exec_ssh_cmd(
+            ssh, "sudo chmod 755 %s" % REPLICATOR_STATE, get_pty=True)
+        os.remove(state_file)
 
         args = self._parse_replicator_conn_info(self._conn_info)
         self._copy_replicator_cmd(ssh)
@@ -468,9 +496,9 @@ class Replicator(object):
             self._config_dir, "ssl/client/client-creds.zip")
 
         utils.exec_ssh_cmd(
-            self._ssh, "sudo cp -f %s /tmp/creds.zip" % zipFile)
+            self._ssh, "sudo cp -f %s /tmp/creds.zip" % zipFile, get_pty=True)
         utils.exec_ssh_cmd(
-            self._ssh, "sudo chmod +r /tmp/creds.zip")
+            self._ssh, "sudo chmod +r /tmp/creds.zip", get_pty=True)
 
         sftp = paramiko.SFTPClient.from_transport(
             self._ssh.get_transport())
@@ -481,6 +509,8 @@ class Replicator(object):
 
         zFile = zipfile.ZipFile(localCertZip)
         zFile.extractall(path=self._cert_dir)
+        zFile.close()
+        os.remove(localCertZip)
         return {
             "client_cert": clientCrt,
             "client_key": clientKey,
@@ -502,6 +532,12 @@ class Replicator(object):
             ret += chunk["length"]
         return ret / units.Gi
 
+    def _find_vol_state(self, name, state):
+        for vol in state:
+            if vol["device-name"] == name:
+                return vol
+        return None
+
     def replicate_disks(self, source_volumes_info, backup_writer):
         """
         Fetch the block diff and send it to the backup_writer.
@@ -511,41 +547,47 @@ class Replicator(object):
         are part of a file or not.
 
         source_volumes_info should be of the following format:
-
-        {
-            "disk_id": the_provider_ID_of_the_volume,
-            "disk_path": /dev/sdb,
-        }
+        [
+            {
+                "disk_id": the_provider_ID_of_the_volume,
+                "disk_path": /dev/sdb,
+            },
+        ]
         """
+        LOG.warning("Source volumes info is: %r" % source_volumes_info)
         state = self._repl_state
         isInitial = False
         if state is None or len(state) == 0:
             isInitial = True
+        curr_state = self._cli.get_status(brief=False)
 
         for volume in source_volumes_info:
-            dst_vol = None
-            for vol in self._volumes_info:
+            dst_vol_idx = None
+            for idx, vol in enumerate(self._volumes_info):
                 if vol["disk_id"] == volume["disk_id"]:
-                    dst_vol = vol
+                    dst_vol_idx = idx
                     break
 
-            if dst_vol is None:
+            if dst_vol_idx is None:
                 raise exception.CoriolisException(
                     "failed to find a coresponding volume in volumes_info"
                     " for %s" % volume["disk_id"])
 
+            dst_vol = self._volumes_info[dst_vol_idx]
+
             devName = volume["disk_path"]
             if devName.startswith('/dev'):
-                devNane = devName[5:]
+                devName = devName[5:]
 
+            state_for_vol = self._find_vol_state(devName, curr_state)
             if isInitial and dst_vol.get("zeroed", False) is True:
                 # This is an initial sync of the disk, and we can
                 # skip zero chunks
                 chunks = self._cli.get_chunks(
-                    devNane, skip_zeros=True)
+                    devName, skip_zeros=True)
             else:
                 # subsequent sync. Get changes.
-                chunks = self._cli.get_changes(devNane)
+                chunks = self._cli.get_changes(devName)
 
             size = self._get_size_from_chunks(chunks)
 
@@ -555,17 +597,35 @@ class Replicator(object):
                 len(chunks), message_format=msg)
 
             total = 0
+
+            @utils.retry_on_error()
+            def _download_or_reconnect(chunk):
+                try:
+                    data = self._cli.download_chunk(
+                        devName, chunk)
+                    return data
+                except BaseException as err:
+                    LOG.error("Error downloading chunk: %r" % err)
+                    try:
+                        self._check_replicator_errors()
+                    except BaseException as err:
+                        LOG.error("replicator is in error: %r" % err)
+                        self.init_replicator()
+                        raise
+                    raise
+
             with backup_writer.open("", volume['disk_id']) as destination:
                 for chunk in chunks:
                     offset = int(chunk["offset"])
                     destination.seek(offset)
-                    data = self._cli.download_chunk(devName, chunk)
+                    data = _download_or_reconnect(chunk)
                     destination.write(data)
                     total += 1
                     self._event_manager.set_percentage_step(
                         perc_step, total)
+            dst_vol["replica_state"] = state_for_vol
 
-        self._repl_state = self._cli.get_status()
+        self._repl_state = curr_state
         return self._repl_state
 
     def _download_full_disk(self, disk, path):

+ 1 - 0
coriolis/secrets.py

@@ -9,6 +9,7 @@ from coriolis import keystone
 
 
 def get_secret(ctxt, secret_ref):
+    keystone.create_trust(ctxt)
     session = keystone.create_keystone_session(ctxt)
     barbican = barbican_client.Client(session=session)
     return json.loads(barbican.secrets.get(secret_ref).payload)

+ 3 - 2
coriolis/utils.py

@@ -172,10 +172,11 @@ def list_ssh_dir(ssh, remote_path):
 
 
 @retry_on_error()
-def exec_ssh_cmd(ssh, cmd, environment=None):
+def exec_ssh_cmd(ssh, cmd, environment=None, get_pty=False):
     LOG.debug("Executing SSH command: %s", cmd)
     LOG.debug("SSH command environment: %s", environment)
-    stdin, stdout, stderr = ssh.exec_command(cmd, environment=environment)
+    stdin, stdout, stderr = ssh.exec_command(
+        cmd, environment=environment, get_pty=get_pty)
     exit_code = stdout.channel.recv_exit_status()
     std_out = stdout.read()
     std_err = stderr.read()