test_multipart_driver.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. """
  2. Provider-agnostic unit tests for the base multipart upload driver
  3. (``BaseBucketObject._upload_multipart``).
  4. The driver is the engine behind transparent large uploads on providers that do
  5. not override it (GCP, OpenStack Swift). Because the mock provider is AWS-backed
  6. and AWS overrides the driver with boto3's native uploader, the driver is
  7. exercised here directly against in-memory fakes so it has coverage in CI
  8. without cloud credentials.
  9. """
  10. import threading
  11. import unittest
  12. from io import BytesIO
  13. from cloudbridge.base.resources import BaseBucketObject
  14. from cloudbridge.base.resources import BaseMultipartUpload
  15. from cloudbridge.base.resources import BaseUploadPart
  16. from cloudbridge.interfaces.exceptions import InvalidValueException
  17. class _Recorder:
  18. """Thread-safe sink shared by the original and all cloned fake services."""
  19. def __init__(self):
  20. self._lock = threading.Lock()
  21. self.parts = {} # part_number -> bytes
  22. self.services_used = set() # id() of each service that uploaded a part
  23. self.clone_count = 0
  24. self.completed_order = None
  25. self.aborted = False
  26. self.active = 0
  27. self.max_active = 0
  28. self.fail_on_part = None # part_number that should raise
  29. def record_part(self, service, part_number, data):
  30. with self._lock:
  31. self.active += 1
  32. self.max_active = max(self.max_active, self.active)
  33. try:
  34. if self.fail_on_part == part_number:
  35. raise RuntimeError("boom on part %d" % part_number)
  36. # Hold briefly so concurrent uploads genuinely overlap.
  37. time_to_sleep = 0.02
  38. _sleep(time_to_sleep)
  39. with self._lock:
  40. self.parts[part_number] = bytes(data)
  41. self.services_used.add(id(service))
  42. finally:
  43. with self._lock:
  44. self.active -= 1
  45. def _sleep(seconds):
  46. # Indirection so the deterministic tests can monkeypatch if needed; a plain
  47. # sleep is fine here and keeps the overlap window small.
  48. threading.Event().wait(seconds)
  49. class _FakeService:
  50. def __init__(self, recorder, provider):
  51. self._recorder = recorder
  52. self._provider = provider
  53. def create_multipart_upload(self, bucket, object_name):
  54. return BaseMultipartUpload(self._provider, bucket, object_name, "upl")
  55. def upload_part(self, bucket, upload, part_number, data):
  56. self._recorder.record_part(self, part_number, data)
  57. return BaseUploadPart(part_number, "etag-%d" % part_number)
  58. def complete_multipart_upload(self, bucket, upload, parts):
  59. ordered = sorted(parts, key=lambda p: p.part_number)
  60. self._recorder.completed_order = [p.part_number for p in ordered]
  61. return b"".join(self._recorder.parts[p.part_number] for p in ordered)
  62. def abort_multipart_upload(self, bucket, upload):
  63. self._recorder.aborted = True
  64. class _FakeStorage:
  65. def __init__(self, service):
  66. self._bucket_objects = service
  67. class _FakeProvider:
  68. def __init__(self, recorder):
  69. self._recorder = recorder
  70. self.storage = _FakeStorage(_FakeService(recorder, self))
  71. def clone(self, zone=None):
  72. self._recorder.clone_count += 1
  73. return _FakeProvider(self._recorder)
  74. def _get_config_value(self, key, default_value=None):
  75. return default_value
  76. class _DriverObject(BaseBucketObject):
  77. """A BaseBucketObject wired to fakes, with a tiny minimum part size so
  78. tests can use small payloads."""
  79. CB_MULTIPART_MIN_PART_SIZE = 1
  80. def __init__(self, provider, part_size, concurrency):
  81. super(_DriverObject, self).__init__(provider)
  82. self._part_size = part_size
  83. self._concurrency = concurrency
  84. @property
  85. def id(self):
  86. return "obj"
  87. @property
  88. def name(self):
  89. return "obj"
  90. @property
  91. def bucket(self):
  92. return "BUCKET"
  93. @property
  94. def _multipart_part_size(self):
  95. return self._part_size
  96. @property
  97. def _multipart_max_concurrency(self):
  98. return self._concurrency
  99. class MultipartDriverTestCase(unittest.TestCase):
  100. def _driver(self, recorder, part_size, concurrency):
  101. return _DriverObject(_FakeProvider(recorder), part_size, concurrency)
  102. def test_reassembles_payload_in_order(self):
  103. recorder = _Recorder()
  104. driver = self._driver(recorder, part_size=4, concurrency=3)
  105. content = b"abcdefghijABCDEFGHIJ0123456789x" # 31 bytes -> 8 parts
  106. result = driver._upload_multipart(BytesIO(content))
  107. self.assertEqual(result, content)
  108. self.assertEqual(recorder.completed_order, list(range(1, 9)))
  109. # Final part is the short remainder (3 bytes).
  110. self.assertEqual(recorder.parts[8], content[28:])
  111. def test_handles_short_reads_without_undersized_parts(self):
  112. recorder = _Recorder()
  113. driver = self._driver(recorder, part_size=8, concurrency=2)
  114. class _DripStream:
  115. """Returns at most 3 bytes per read to simulate a socket-like
  116. stream; the driver must coalesce reads up to the part size."""
  117. def __init__(self, data):
  118. self._buf = BytesIO(data)
  119. def read(self, size):
  120. return self._buf.read(min(size, 3))
  121. content = bytes(range(20)) # 20 bytes, part_size 8 -> 8,8,4
  122. result = driver._upload_multipart(_DripStream(content))
  123. self.assertEqual(result, content)
  124. self.assertEqual(len(recorder.parts[1]), 8)
  125. self.assertEqual(len(recorder.parts[2]), 8)
  126. self.assertEqual(len(recorder.parts[3]), 4)
  127. def test_uploads_parts_concurrently_via_cloned_services(self):
  128. recorder = _Recorder()
  129. concurrency = 4
  130. driver = self._driver(recorder, part_size=1, concurrency=concurrency)
  131. # 12 parts (one byte each) across a pool of 4 clones.
  132. content = b"0123456789ab"
  133. driver._upload_multipart(BytesIO(content))
  134. # A clone per worker, reused across parts.
  135. self.assertEqual(recorder.clone_count, concurrency)
  136. self.assertEqual(len(recorder.services_used), concurrency)
  137. # Real parallelism happened, bounded by the configured concurrency.
  138. self.assertGreater(recorder.max_active, 1)
  139. self.assertLessEqual(recorder.max_active, concurrency)
  140. def test_single_concurrency_does_not_clone(self):
  141. recorder = _Recorder()
  142. driver = self._driver(recorder, part_size=4, concurrency=1)
  143. driver._upload_multipart(BytesIO(b"abcdefghij"))
  144. self.assertEqual(recorder.clone_count, 0)
  145. self.assertEqual(recorder.max_active, 1)
  146. def test_aborts_and_raises_on_part_failure(self):
  147. recorder = _Recorder()
  148. recorder.fail_on_part = 2
  149. driver = self._driver(recorder, part_size=4, concurrency=3)
  150. with self.assertRaises(Exception):
  151. driver._upload_multipart(BytesIO(b"abcdefghijklmnop"))
  152. self.assertTrue(recorder.aborted)
  153. def test_part_size_below_minimum_raises(self):
  154. recorder = _Recorder()
  155. # Use the real minimum here (not the test override) by going through a
  156. # driver whose part size is one below the production floor.
  157. driver = _DriverObject(_FakeProvider(recorder), part_size=4,
  158. concurrency=2)
  159. driver.CB_MULTIPART_MIN_PART_SIZE = 5
  160. with self.assertRaises(InvalidValueException):
  161. driver._upload_multipart(BytesIO(b"abc"))
  162. # Nothing should have been created before validation failed.
  163. self.assertEqual(recorder.completed_order, None)
# Run the tests when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()