utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
# Copyright 2016 Cloudbase Solutions Srl
# All Rights Reserved.

import functools
import hashlib
import importlib
import io
import json
import os
import pickle
import re
import socket
import subprocess
import time
import traceback

import OpenSSL
from oslo_config import cfg
from oslo_log import log as logging
from oslo_serialization import jsonutils
import paramiko

from coriolis import constants
from coriolis import exception
from coriolis import secrets
# Configuration options registered by this module.
opts = [
    cfg.StrOpt('qemu_img_path',
               default='qemu-img',
               help='The path of the qemu-img tool.'),
]

CONF = cfg.CONF
# Register oslo.log's own options before adding ours.
logging.register_options(CONF)
CONF.register_opts(opts)

# Module-level logger used throughout this file.
LOG = logging.getLogger(__name__)
  31. def setup_logging():
  32. logging.setup(CONF, 'coriolis')
  33. def ignore_exceptions(func):
  34. @functools.wraps(func)
  35. def _ignore_exceptions(*args, **kwargs):
  36. try:
  37. return func(*args, **kwargs)
  38. except Exception as ex:
  39. LOG.exception(ex)
  40. return _ignore_exceptions
  41. def get_single_result(lis):
  42. """ Indexes the head of a single element list.
  43. Raises a KeyError if the list is empty or its length is greater than 1.
  44. """
  45. if len(lis) == 0:
  46. raise KeyError("Result list is empty.")
  47. elif len(lis) > 1:
  48. raise KeyError("More than one result in list: '%s'" % lis)
  49. return lis[0]
  50. def retry_on_error(max_attempts=5, sleep_seconds=0,
  51. terminal_exceptions=[]):
  52. def _retry_on_error(func):
  53. @functools.wraps(func)
  54. def _exec_retry(*args, **kwargs):
  55. i = 0
  56. while True:
  57. try:
  58. return func(*args, **kwargs)
  59. except KeyboardInterrupt as ex:
  60. LOG.debug("Got a KeyboardInterrupt, skip retrying")
  61. LOG.exception(ex)
  62. raise
  63. except Exception as ex:
  64. if any([isinstance(ex, tex)
  65. for tex in terminal_exceptions]):
  66. raise
  67. i += 1
  68. if i < max_attempts:
  69. LOG.warn("Exception occurred, retrying: %s", ex)
  70. time.sleep(sleep_seconds)
  71. else:
  72. raise
  73. return _exec_retry
  74. return _retry_on_error
  75. def get_udev_net_rules(net_ifaces_info):
  76. content = ""
  77. for name, mac_address in net_ifaces_info:
  78. content += ('SUBSYSTEM=="net", ACTION=="add", DRIVERS=="?*", '
  79. 'ATTR{address}=="%(mac_address)s", NAME="%(name)s"\n' %
  80. {"name": name, "mac_address": mac_address.lower()})
  81. return content
  82. def parse_os_release(ssh):
  83. os_release_info = exec_ssh_cmd(
  84. ssh, "[ -f '/etc/os-release' ] && cat /etc/os-release || true").decode()
  85. info = {}
  86. for line in os_release_info.splitlines():
  87. if "=" not in line:
  88. continue
  89. k, v = line.split("=")
  90. info[k] = v.strip('"')
  91. if info.get("ID") and info.get("VERSION_ID"):
  92. return (info.get("ID"), info.get("VERSION_ID"))
  93. def parse_lsb_release(ssh):
  94. os_release_info = exec_ssh_cmd(
  95. ssh, "[ -f '/etc/os-release' ] && cat /etc/os-release || true").decode()
  96. out = exec_ssh_cmd(ssh, "lsb_release -a || true").decode()
  97. dist_id = re.findall('^Distributor ID:\s(.*)$', out, re.MULTILINE)
  98. release = re.findall('^Release:\s(.*)$', out, re.MULTILINE)
  99. if dist_id and release:
  100. return (dist_id[0], release[0])
  101. def get_linux_os_info(ssh):
  102. info = parse_os_release(ssh)
  103. if info is None:
  104. #fall back to lsb_release
  105. return parse_lsb_release(ssh)
  106. return info
  107. @retry_on_error()
  108. def test_ssh_path(ssh, remote_path):
  109. sftp = ssh.open_sftp()
  110. try:
  111. sftp.stat(remote_path)
  112. return True
  113. except IOError as ex:
  114. if ex.args[0] == 2:
  115. return False
  116. raise
  117. @retry_on_error()
  118. def read_ssh_file(ssh, remote_path):
  119. sftp = ssh.open_sftp()
  120. return sftp.open(remote_path, 'rb').read()
  121. @retry_on_error()
  122. def write_ssh_file(ssh, remote_path, content):
  123. sftp = ssh.open_sftp()
  124. sftp.open(remote_path, 'wb').write(content)
  125. @retry_on_error()
  126. def list_ssh_dir(ssh, remote_path):
  127. sftp = ssh.open_sftp()
  128. return sftp.listdir(remote_path)
  129. @retry_on_error()
  130. def exec_ssh_cmd(ssh, cmd):
  131. LOG.debug("Executing SSH command: %s", cmd)
  132. stdin, stdout, stderr = ssh.exec_command(cmd)
  133. exit_code = stdout.channel.recv_exit_status()
  134. std_out = stdout.read()
  135. std_err = stderr.read()
  136. if exit_code:
  137. raise exception.CoriolisException(
  138. "Command \"%s\" failed with exit code: %s\n"
  139. "stdout: %s\nstd_err: %s" %
  140. (cmd, exit_code, std_out, std_err))
  141. return std_out
  142. def exec_ssh_cmd_chroot(ssh, chroot_dir, cmd):
  143. return exec_ssh_cmd(ssh, "sudo chroot %s %s" % (chroot_dir, cmd))
  144. def check_fs(ssh, fs_type, dev_path):
  145. try:
  146. out = exec_ssh_cmd(
  147. ssh, "sudo fsck -p -t %s %s" % (fs_type, dev_path)).decode()
  148. LOG.debug("File system checked:\n%s", out)
  149. except Exception as ex:
  150. LOG.warn("Checking file system returned an error:\n%s", str(ex))
  151. def _check_port_open(host, port):
  152. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  153. try:
  154. s.settimeout(1)
  155. s.connect((host, port))
  156. return True
  157. except (ConnectionRefusedError, socket.timeout, OSError):
  158. return False
  159. finally:
  160. s.close()
  161. def wait_for_port_connectivity(address, port, max_wait=300):
  162. i = 0
  163. while not _check_port_open(address, port) and i < max_wait:
  164. time.sleep(1)
  165. i += 1
  166. if i == max_wait:
  167. raise exception.CoriolisException("Connection failed on port %s" %
  168. port)
  169. def exec_process(args):
  170. p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  171. std_out, std_err = p.communicate()
  172. if p.returncode:
  173. raise exception.CoriolisException(
  174. "Command \"%s\" failed with exit code: %s\nstdout: %s\nstd_err: %s"
  175. % (args, p.returncode, std_out, std_err))
  176. return std_out
  177. def get_disk_info(disk_path):
  178. out = exec_process([CONF.qemu_img_path, 'info', '--output=json',
  179. disk_path])
  180. disk_info = json.loads(out.decode())
  181. if disk_info["format"] == "vpc":
  182. disk_info["format"] = constants.DISK_FORMAT_VHD
  183. return disk_info
  184. def convert_disk_format(disk_path, target_disk_path, target_format,
  185. preallocated=False):
  186. allocation_args = []
  187. if preallocated:
  188. if target_format != constants.DISK_FORMAT_VHD:
  189. raise NotImplementedError(
  190. "Preallocation is supported only for the VHD format.")
  191. allocation_args = ['-o', 'subformat=fixed']
  192. if target_format == constants.DISK_FORMAT_VHD:
  193. target_format = "vpc"
  194. args = ([CONF.qemu_img_path, 'convert', '-O', target_format] +
  195. allocation_args +
  196. [disk_path, target_disk_path])
  197. try:
  198. exec_process(args)
  199. except Exception:
  200. ignore_exceptions(os.remove)(target_disk_path)
  201. raise
  202. def get_hostname():
  203. return socket.gethostname()
  204. def get_exception_details():
  205. return traceback.format_exc()
  206. def walk_class_hierarchy(clazz, encountered=None):
  207. """Walk class hierarchy, yielding most derived classes first."""
  208. if not encountered:
  209. encountered = []
  210. for subclass in clazz.__subclasses__():
  211. if subclass not in encountered:
  212. encountered.append(subclass)
  213. # drill down to leaves first
  214. for subsubclass in walk_class_hierarchy(subclass, encountered):
  215. yield subsubclass
  216. yield subclass
  217. def get_ssl_cert_thumbprint(context, host, port=443, digest_algorithm="sha1"):
  218. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  219. ssl_sock = context.wrap_socket(sock, server_hostname=host)
  220. ssl_sock.connect((host, port))
  221. # binary_form is the only option when the certificate is not validated
  222. cert = ssl_sock.getpeercert(binary_form=True)
  223. sock.close()
  224. x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert)
  225. return x509.digest('sha1').decode()
  226. def get_resources_dir():
  227. return os.path.join(
  228. os.path.dirname(os.path.abspath(__file__)), "resources")
  229. def serialize_key(key, password=None):
  230. key_io = io.StringIO()
  231. key.write_private_key(key_io, password)
  232. return key_io.getvalue()
  233. def deserialize_key(key_bytes, password=None):
  234. key_io = io.StringIO(key_bytes)
  235. return paramiko.RSAKey.from_private_key(key_io, password)
  236. def is_serializable(obj):
  237. pickle.dumps(obj)
def to_dict(obj, max_depth=10):
    """Deep-convert *obj* into plain JSON-compatible primitives.

    Round-trips through jsonutils with a raised depth limit, since
    jsonutils.dumps() has a max_depth of 3 by default.
    """
    # Adapter matching jsonutils' to_primitive signature, overriding only
    # the depth limit; passed as the dumps() fallback serializer.
    def _to_primitive(value, convert_instances=False,
                      convert_datetime=True, level=0,
                      max_depth=max_depth):
        return jsonutils.to_primitive(
            value, convert_instances, convert_datetime, level, max_depth)
    return jsonutils.loads(jsonutils.dumps(obj, default=_to_primitive))
  246. def topological_graph_sorting(items, id="id", depends_on="depends_on",
  247. sort_key=None):
  248. """
  249. Kahn's algorithm
  250. """
  251. if sort_key:
  252. # Sort siblings
  253. items = sorted(items, key=lambda t: t[sort_key], reverse=True)
  254. a = []
  255. for i in items:
  256. a.append({"id": i[id],
  257. "depends_on": list(i[depends_on] or []),
  258. "item": i})
  259. s = []
  260. l = []
  261. for n in a:
  262. if not n["depends_on"]:
  263. s.append(n)
  264. while s:
  265. n = s.pop()
  266. l.append(n["item"])
  267. for m in a:
  268. if n["id"] in m["depends_on"]:
  269. m["depends_on"].remove(n["id"])
  270. if not m["depends_on"]:
  271. s.append(m)
  272. if len(l) != len(a):
  273. raise ValueError("The graph contains cycles")
  274. return l
  275. def load_class(class_path):
  276. LOG.debug('Loading class \'%s\'' % class_path)
  277. parts = class_path.rsplit('.', 1)
  278. module = __import__(parts[0], fromlist=parts[1])
  279. return getattr(module, parts[1])
  280. def check_md5(data, md5):
  281. m = hashlib.md5()
  282. m.update(data)
  283. new_md5 = m.hexdigest()
  284. if new_md5 != md5:
  285. raise exception.CoriolisException("MD5 check failed")
  286. def get_secret_connection_info(ctxt, connection_info):
  287. secret_ref = connection_info.get("secret_ref")
  288. if secret_ref:
  289. LOG.info("Retrieving connection info from secret: %s", secret_ref)
  290. connection_info = secrets.get_secret(ctxt, secret_ref)
  291. return connection_info
  292. def parse_int_value(value):
  293. try:
  294. return int(str(value))
  295. except ValueError:
  296. raise exception.InvalidInput("Invalid integer: %s" % value)