Преглед на файлове

Implement tunneling fallback

In case we can't connect directly to the replicator, tunnel through
using SSH, then attempt the connection again.
Gabriel Adrian Samfira преди 7 години
родител
ревизия
ecad074061
променени са 2 файла, в които са добавени 112 реда и са изтрити 25 реда
  1. 111 25
      coriolis/providers/replicator.py
  2. 1 0
      requirements.txt

+ 111 - 25
coriolis/providers/backup_readers.py → coriolis/providers/replicator.py

@@ -14,6 +14,8 @@ from coriolis import exception
 from oslo_log import log as logging
 from oslo_utils import units
 
+from sshtunnel import SSHTunnelForwarder
+
 LOG = logging.getLogger(__name__)
 
 HASH_METHOD_SHA256 = "sha256"
@@ -21,43 +23,116 @@ HASH_METHOD_XXHASH = "xxhash"
 
 REPLICATOR_PATH = "/usr/bin/replicator"
 REPLICATOR_STATE = "/tmp/replicator_state.json"
+REPLICATOR_USERNAME = "replicator"
+DEFAULT_REPLICATOR_PORT = 4433
 
 
 class Client(object):
 
     def __init__(self, ip, port, credentials, ssh_conn_info,
-                 event_handler, use_gzip=False):
+                 event_handler, use_compression=False):
         self._ip = ip
         self._port = port
+        self._event_manager = event_handler
         self._creds = credentials
-        self._use_gzip = use_gzip
+        self._ssh_conn_info = ssh_conn_info
+        self._use_compression = use_compression
         self._cli = self._get_session()
-        self._base_uri = "https://%s:%s" % (self._ip, self._port)
         self._tunnel = None
+        self._ip_via_tunnel = None
+        self._port_via_tunnel = None
+        self._test_connection()
 
     def __del__(self):
         if self._tunnel is not None:
             try:
                 self._tunnel.stop()
-            except BaseException:
-                # TODO(gsamfira): add logging
-                pass
-
-    def _test_connection(self, ip, port):
+            except BaseException as err:
+                LOG.warning(
+                    "failed to stop tunnel: %s" % err)
+
+    @property
+    def repl_host(self):
+        if self._ip_via_tunnel is not None:
+            return self._ip_via_tunnel
+        return self._ip
+
+    @property
+    def repl_port(self):
+        if self._port_via_tunnel is not None:
+            return self._port_via_tunnel
+        return self._port
+
+    @property
+    def _base_uri(self):
+        return "https://%s:%s" % (self.repl_host, self.repl_port)
+
+    def _setup_tunnel_connection(self):
+        self._tunnel = self._get_ssh_tunnel()
+        self._tunnel.start()
+        host, port = self._tunnel.local_bind_address
+        self._ip_via_tunnel = host
+        self._port_via_tunnel = port
+
+    def _test_connection(self):
         """
-        Attempt to connect to the IP/port pair. This is used in
-        _get_replicator_client().
+        Attempt to connect to the IP/port pair. If direct connection
+        fails, set up a SSH tunnel and attempt a connection through that.
         """
-        pass
+        self._event_manager.progress_update(
+            "Testing %s:%s for connectivity" % (self._ip, self._port))
+        try:
+            utils.wait_for_port_connectivity(
+                self._ip, self._port, max_wait=2)
+            return
+        except BaseException as err:
+            LOG.debug("failed to connect to %s:%s Error: %s "
+                      "Trying tunneled connection" % (
+                          self._ip, self._port, err))
+            self._event_manager.progress_update(
+                "Direct connection to replicator failed. Setting up tunnel.")
+            self._setup_tunnel_connection()
 
-    def _get_ssh_tunnel(self, args):
+        self._event_manager.progress_update(
+            "Testing %s:%s for connectivity" % (
+                self.repl_host, self.repl_port))
+        try:
+            utils.wait_for_port_connectivity(
+                self.repl_host, self.repl_port, max_wait=30)
+        except BaseException as err:
+            self._tunnel.stop()
+            LOG.warning(
+                "failed to connect to replicator through tunnel: %s" % err)
+            raise
+
+    def _get_ssh_tunnel(self):
         """
         gets a SSH tunnel object. Note, this does not start the tunnel,
-        it simply creates the object. It is the job of the caller to call
-        start() and cleanup after itself by calling stop() on the tunnel
-        object
+        it simply creates the object, without actually connecting.
         """
-        pass
+        remote_host = self._ssh_conn_info["hostname"]
+        remote_port = self._ssh_conn_info["port"]
+        remote_user = self._ssh_conn_info["username"]
+        local_host = "127.0.0.1"
+        remote_port = self._ssh_conn_info.get("port", 22)
+
+        pkey = self._ssh_conn_info.get("pkey")
+        password = self._ssh_conn_info.get("password")
+        if any([pkey, password]) is False:
+            raise exception.CoriolisException(
+                "Either password or pkey is required")
+
+        server = SSHTunnelForwarder(
+            (remote_host, remote_port),
+            ssh_username=remote_user,
+            ssh_pkey=pkey,
+            ssh_password=password,
+            # bind to remote replicator port
+            remote_bind_address=(local_host, self._port),
+            # select random port on this end.
+            local_bind_address=(local_host, 0),
+        )
+        return server
 
     def raw_disk_uri(self, disk_name):
         diskUri = "%s/device/%s" % (self._base_uri, disk_name)
@@ -99,13 +174,14 @@ class Client(object):
 
     def download_chunk(self, disk, chunk):
         diskUri = self.raw_disk_uri(disk)
+
         offset = int(chunk["offset"])
         end = offset + int(chunk["length"]) - 1
 
         headers = {
             "Range": "bytes=%s-%s" % (offset, end),
         }
-        if self._use_gzip is False:
+        if self._use_compression is False:
             headers["Accept-encoding"] = "identity"
 
         data = self._cli.get(
@@ -117,7 +193,7 @@ class Client(object):
 class Replicator(object):
 
     def __init__(self, conn_info, event_manager, volumes_info, replica_state,
-                 use_gzip=False, ignore_mounted=True,
+                 use_compression=False, ignore_mounted=True,
                  hash_method=HASH_METHOD_SHA256, watch_devices=True,
                  chunk_size=10485760):
         self._event_manager = event_manager
@@ -127,7 +203,7 @@ class Replicator(object):
         self._cert_dir = None
         self._lock = threading.Lock()
         self._volumes_info = volumes_info
-        self._use_gzip = use_gzip
+        self._use_compression = use_compression
         self._watch_devices = watch_devices
         self._hash_method = hash_method
         self._ignore_mounted = ignore_mounted
@@ -152,7 +228,7 @@ class Replicator(object):
             args["ip"], args["port"],
             credentials, ssh_conn_info,
             self._event_manager,
-            use_gzip=self._use_gzip)
+            use_compression=self._use_compression)
 
     def _setup_ssh(self):
         args = self._parse_source_ssh_conn_info(
@@ -253,7 +329,7 @@ class Replicator(object):
         # The IP should be the same one as the SSH IP.
         # Only the port will differ
         ip = conn_info.get("ip", None)
-        port = conn_info.get("replicator_port", 4433)
+        port = conn_info.get("replicator_port", DEFAULT_REPLICATOR_PORT)
         return {
             "ip": ip,
             "port": port,
@@ -290,8 +366,10 @@ class Replicator(object):
         user_exists = utils.exec_ssh_cmd(
             ssh, "getent passwd replicator > /dev/null && echo 1 || echo 0")
         if int(user_exists) == 0:
-            utils.exec_ssh_cmd(ssh, "sudo useradd -m -s /bin/bash replicator")
-            utils.exec_ssh_cmd(ssh, "sudo usermod -aG disk replicator")
+            utils.exec_ssh_cmd(
+                ssh, "sudo useradd -m -s /bin/bash %s" % REPLICATOR_USERNAME)
+            utils.exec_ssh_cmd(
+                ssh, "sudo usermod -aG disk %s" % REPLICATOR_USERNAME)
 
     @utils.retry_on_error()
     def _exec_replicator(self, ssh, args, state_file):
@@ -303,7 +381,10 @@ class Replicator(object):
             ssh, "mktemp -d").decode().rstrip("\n")
         utils.exec_ssh_cmd(
             ssh,
-            "sudo chown replicator:replicator %s" % self._config_dir)
+            "sudo chown %(user)s:%(user)s %(config_dir)s" % {
+                "config_dir": self._config_dir,
+                "user": REPLICATOR_USERNAME,
+            })
         cmdline = ("/usr/bin/replicator -certificate-hosts=%(cert_hosts)s "
                    "-config-dir=%(cfgdir)s -hash-method=%(hash_method)s "
                    "-ignore-mounted-disks=%(ignore_mounted)s "
@@ -322,9 +403,14 @@ class Replicator(object):
                    })
         self._event_manager.progress_update("running %s" % cmdline)
         self._stdin, self._stdout, self._stderr = ssh.exec_command(
-            "sudo -u replicator -- %s > /tmp/replicator.log" % cmdline)
+            "sudo -u %(username)s -- %(cmdline)s > /tmp/replicator.log" % {
+                "cmdline": cmdline,
+                "username": REPLICATOR_USERNAME,
+            })
         count = 0
         # wait 5 seconds. If process exits, raise
+        # TODO(gsamfira): create system service? That should take care of
+        # restarting the replicator process if it fails.
         while True:
             if count >= 5:
                 break

+ 1 - 0
requirements.txt

@@ -33,3 +33,4 @@ schedule
 strict-rfc3339
 sqlalchemy
 webob
+sshtunnel