test_multipart_driver.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
"""
Provider-agnostic unit tests for the base multipart upload driver
(``BaseBucketObject._upload_multipart``).

The driver is the engine behind transparent large uploads on providers that do
not override it (GCP, OpenStack Swift). Because the mock provider is AWS-backed
and AWS overrides the driver with boto3's native uploader, the driver is
exercised here directly against in-memory fakes so it has coverage in CI
without cloud credentials. The fakes report every uploaded part, provider
clone, completion and abort to a shared ``_Recorder`` that the tests inspect.
"""
  10. import threading
  11. import unittest
  12. from io import BytesIO
  13. from cloudbridge.base.resources import BaseBucketObject
  14. from cloudbridge.base.resources import BaseMultipartUpload
  15. from cloudbridge.base.resources import BaseUploadPart
  16. from cloudbridge.interfaces.exceptions import InvalidValueException
  17. from cloudbridge.interfaces.resources import UploadConfig
  18. class _Recorder:
  19. """Thread-safe sink shared by the original and all cloned fake services."""
  20. def __init__(self):
  21. self._lock = threading.Lock()
  22. self.parts = {} # part_number -> bytes
  23. self.services_used = set() # id() of each service that uploaded a part
  24. self.clone_count = 0
  25. self.completed_order = None
  26. self.aborted = False
  27. self.active = 0
  28. self.max_active = 0
  29. self.fail_on_part = None # part_number that should raise
  30. def record_part(self, service, part_number, data):
  31. with self._lock:
  32. self.active += 1
  33. self.max_active = max(self.max_active, self.active)
  34. try:
  35. if self.fail_on_part == part_number:
  36. raise RuntimeError("boom on part %d" % part_number)
  37. # Hold briefly so concurrent uploads genuinely overlap.
  38. time_to_sleep = 0.02
  39. _sleep(time_to_sleep)
  40. with self._lock:
  41. self.parts[part_number] = bytes(data)
  42. self.services_used.add(id(service))
  43. finally:
  44. with self._lock:
  45. self.active -= 1
  46. def _sleep(seconds):
  47. # Indirection so the deterministic tests can monkeypatch if needed; a plain
  48. # sleep is fine here and keeps the overlap window small.
  49. threading.Event().wait(seconds)
  50. class _FakeService:
  51. def __init__(self, recorder, provider):
  52. self._recorder = recorder
  53. self._provider = provider
  54. def create_multipart_upload(self, bucket, object_name):
  55. return BaseMultipartUpload(self._provider, bucket, object_name, "upl")
  56. def upload_part(self, bucket, upload, part_number, data):
  57. self._recorder.record_part(self, part_number, data)
  58. return BaseUploadPart(part_number, "etag-%d" % part_number)
  59. def complete_multipart_upload(self, bucket, upload, parts):
  60. ordered = sorted(parts, key=lambda p: p.part_number)
  61. self._recorder.completed_order = [p.part_number for p in ordered]
  62. return b"".join(self._recorder.parts[p.part_number] for p in ordered)
  63. def abort_multipart_upload(self, bucket, upload):
  64. self._recorder.aborted = True
  65. class _FakeStorage:
  66. def __init__(self, service):
  67. self._bucket_objects = service
  68. class _FakeProvider:
  69. def __init__(self, recorder):
  70. self._recorder = recorder
  71. self.storage = _FakeStorage(_FakeService(recorder, self))
  72. def clone(self, zone=None):
  73. self._recorder.clone_count += 1
  74. return _FakeProvider(self._recorder)
  75. def _get_config_value(self, key, default_value=None):
  76. return default_value
  77. class _DriverObject(BaseBucketObject):
  78. """A BaseBucketObject wired to fakes, with a tiny minimum part size so
  79. tests can use small payloads."""
  80. CB_MULTIPART_MIN_PART_SIZE = 1
  81. def __init__(self, provider, part_size, concurrency):
  82. super(_DriverObject, self).__init__(provider)
  83. self._part_size = part_size
  84. self._concurrency = concurrency
  85. @property
  86. def id(self):
  87. return "obj"
  88. @property
  89. def name(self):
  90. return "obj"
  91. @property
  92. def bucket(self):
  93. return "BUCKET"
  94. def _multipart_part_size(self, config=None):
  95. if config is not None and config.part_size is not None:
  96. return config.part_size
  97. return self._part_size
  98. def _multipart_max_concurrency(self, config=None):
  99. if config is not None and config.max_concurrency is not None:
  100. return config.max_concurrency
  101. return self._concurrency
  102. class MultipartDriverTestCase(unittest.TestCase):
  103. def _driver(self, recorder, part_size, concurrency):
  104. return _DriverObject(_FakeProvider(recorder), part_size, concurrency)
  105. def test_reassembles_payload_in_order(self):
  106. recorder = _Recorder()
  107. driver = self._driver(recorder, part_size=4, concurrency=3)
  108. content = b"abcdefghijABCDEFGHIJ0123456789x" # 31 bytes -> 8 parts
  109. result = driver._upload_multipart(BytesIO(content))
  110. self.assertEqual(result, content)
  111. self.assertEqual(recorder.completed_order, list(range(1, 9)))
  112. # Final part is the short remainder (3 bytes).
  113. self.assertEqual(recorder.parts[8], content[28:])
  114. def test_handles_short_reads_without_undersized_parts(self):
  115. recorder = _Recorder()
  116. driver = self._driver(recorder, part_size=8, concurrency=2)
  117. class _DripStream:
  118. """Returns at most 3 bytes per read to simulate a socket-like
  119. stream; the driver must coalesce reads up to the part size."""
  120. def __init__(self, data):
  121. self._buf = BytesIO(data)
  122. def read(self, size):
  123. return self._buf.read(min(size, 3))
  124. content = bytes(range(20)) # 20 bytes, part_size 8 -> 8,8,4
  125. result = driver._upload_multipart(_DripStream(content))
  126. self.assertEqual(result, content)
  127. self.assertEqual(len(recorder.parts[1]), 8)
  128. self.assertEqual(len(recorder.parts[2]), 8)
  129. self.assertEqual(len(recorder.parts[3]), 4)
  130. def test_uploads_parts_concurrently_via_cloned_services(self):
  131. recorder = _Recorder()
  132. concurrency = 4
  133. driver = self._driver(recorder, part_size=1, concurrency=concurrency)
  134. # 12 parts (one byte each) across a pool of 4 clones.
  135. content = b"0123456789ab"
  136. driver._upload_multipart(BytesIO(content))
  137. # A clone per worker, reused across parts.
  138. self.assertEqual(recorder.clone_count, concurrency)
  139. self.assertEqual(len(recorder.services_used), concurrency)
  140. # Real parallelism happened, bounded by the configured concurrency.
  141. self.assertGreater(recorder.max_active, 1)
  142. self.assertLessEqual(recorder.max_active, concurrency)
  143. def test_single_concurrency_does_not_clone(self):
  144. recorder = _Recorder()
  145. driver = self._driver(recorder, part_size=4, concurrency=1)
  146. driver._upload_multipart(BytesIO(b"abcdefghij"))
  147. self.assertEqual(recorder.clone_count, 0)
  148. self.assertEqual(recorder.max_active, 1)
  149. def test_per_call_config_overrides_concurrency(self):
  150. recorder = _Recorder()
  151. # The driver's default concurrency is 1 (no clones); a per-call config
  152. # bumps it to 3, engaging the clone-per-worker pool.
  153. driver = self._driver(recorder, part_size=1, concurrency=1)
  154. driver._upload_multipart(BytesIO(b"0123456789ab"),
  155. UploadConfig(max_concurrency=3))
  156. self.assertEqual(recorder.clone_count, 3)
  157. self.assertGreater(recorder.max_active, 1)
  158. self.assertLessEqual(recorder.max_active, 3)
  159. def test_aborts_and_raises_on_part_failure(self):
  160. recorder = _Recorder()
  161. recorder.fail_on_part = 2
  162. driver = self._driver(recorder, part_size=4, concurrency=3)
  163. with self.assertRaises(Exception):
  164. driver._upload_multipart(BytesIO(b"abcdefghijklmnop"))
  165. self.assertTrue(recorder.aborted)
  166. def test_part_size_below_minimum_raises(self):
  167. recorder = _Recorder()
  168. # Use the real minimum here (not the test override) by going through a
  169. # driver whose part size is one below the production floor.
  170. driver = _DriverObject(_FakeProvider(recorder), part_size=4,
  171. concurrency=2)
  172. driver.CB_MULTIPART_MIN_PART_SIZE = 5
  173. with self.assertRaises(InvalidValueException):
  174. driver._upload_multipart(BytesIO(b"abc"))
  175. # Nothing should have been created before validation failed.
  176. self.assertEqual(recorder.completed_order, None)
# Allow running this file directly as a script, in addition to discovery by a
# test runner; unittest.main() collects the TestCase classes in this module.
if __name__ == "__main__":
    unittest.main()