|
@@ -0,0 +1,209 @@
|
|
|
|
|
+"""
|
|
|
|
|
+Provider-agnostic unit tests for the base multipart upload driver
|
|
|
|
|
+(``BaseBucketObject._upload_multipart``).
|
|
|
|
|
+
|
|
|
|
|
+The driver is the engine behind transparent large uploads on providers that do
|
|
|
|
|
+not override it (GCP, OpenStack Swift). Because the mock provider is AWS-backed
|
|
|
|
|
+and AWS overrides the driver with boto3's native uploader, the driver is
|
|
|
|
|
+exercised here directly against in-memory fakes so it has coverage in CI
|
|
|
|
|
+without cloud credentials.
|
|
|
|
|
+"""
|
|
|
|
|
+import threading
|
|
|
|
|
+import unittest
|
|
|
|
|
+from io import BytesIO
|
|
|
|
|
+
|
|
|
|
|
+from cloudbridge.base.resources import BaseBucketObject
|
|
|
|
|
+from cloudbridge.base.resources import BaseMultipartUpload
|
|
|
|
|
+from cloudbridge.base.resources import BaseUploadPart
|
|
|
|
|
+from cloudbridge.interfaces.exceptions import InvalidValueException
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _Recorder:
|
|
|
|
|
+ """Thread-safe sink shared by the original and all cloned fake services."""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self):
|
|
|
|
|
+ self._lock = threading.Lock()
|
|
|
|
|
+ self.parts = {} # part_number -> bytes
|
|
|
|
|
+ self.services_used = set() # id() of each service that uploaded a part
|
|
|
|
|
+ self.clone_count = 0
|
|
|
|
|
+ self.completed_order = None
|
|
|
|
|
+ self.aborted = False
|
|
|
|
|
+ self.active = 0
|
|
|
|
|
+ self.max_active = 0
|
|
|
|
|
+ self.fail_on_part = None # part_number that should raise
|
|
|
|
|
+
|
|
|
|
|
+ def record_part(self, service, part_number, data):
|
|
|
|
|
+ with self._lock:
|
|
|
|
|
+ self.active += 1
|
|
|
|
|
+ self.max_active = max(self.max_active, self.active)
|
|
|
|
|
+ try:
|
|
|
|
|
+ if self.fail_on_part == part_number:
|
|
|
|
|
+ raise RuntimeError("boom on part %d" % part_number)
|
|
|
|
|
+ # Hold briefly so concurrent uploads genuinely overlap.
|
|
|
|
|
+ time_to_sleep = 0.02
|
|
|
|
|
+ _sleep(time_to_sleep)
|
|
|
|
|
+ with self._lock:
|
|
|
|
|
+ self.parts[part_number] = bytes(data)
|
|
|
|
|
+ self.services_used.add(id(service))
|
|
|
|
|
+ finally:
|
|
|
|
|
+ with self._lock:
|
|
|
|
|
+ self.active -= 1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _sleep(seconds):
|
|
|
|
|
+ # Indirection so the deterministic tests can monkeypatch if needed; a plain
|
|
|
|
|
+ # sleep is fine here and keeps the overlap window small.
|
|
|
|
|
+ threading.Event().wait(seconds)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _FakeService:
|
|
|
|
|
+ def __init__(self, recorder, provider):
|
|
|
|
|
+ self._recorder = recorder
|
|
|
|
|
+ self._provider = provider
|
|
|
|
|
+
|
|
|
|
|
+ def create_multipart_upload(self, bucket, object_name):
|
|
|
|
|
+ return BaseMultipartUpload(self._provider, bucket, object_name, "upl")
|
|
|
|
|
+
|
|
|
|
|
+ def upload_part(self, bucket, upload, part_number, data):
|
|
|
|
|
+ self._recorder.record_part(self, part_number, data)
|
|
|
|
|
+ return BaseUploadPart(part_number, "etag-%d" % part_number)
|
|
|
|
|
+
|
|
|
|
|
+ def complete_multipart_upload(self, bucket, upload, parts):
|
|
|
|
|
+ ordered = sorted(parts, key=lambda p: p.part_number)
|
|
|
|
|
+ self._recorder.completed_order = [p.part_number for p in ordered]
|
|
|
|
|
+ return b"".join(self._recorder.parts[p.part_number] for p in ordered)
|
|
|
|
|
+
|
|
|
|
|
+ def abort_multipart_upload(self, bucket, upload):
|
|
|
|
|
+ self._recorder.aborted = True
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _FakeStorage:
|
|
|
|
|
+ def __init__(self, service):
|
|
|
|
|
+ self._bucket_objects = service
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _FakeProvider:
|
|
|
|
|
+ def __init__(self, recorder):
|
|
|
|
|
+ self._recorder = recorder
|
|
|
|
|
+ self.storage = _FakeStorage(_FakeService(recorder, self))
|
|
|
|
|
+
|
|
|
|
|
+ def clone(self, zone=None):
|
|
|
|
|
+ self._recorder.clone_count += 1
|
|
|
|
|
+ return _FakeProvider(self._recorder)
|
|
|
|
|
+
|
|
|
|
|
+ def _get_config_value(self, key, default_value=None):
|
|
|
|
|
+ return default_value
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class _DriverObject(BaseBucketObject):
|
|
|
|
|
+ """A BaseBucketObject wired to fakes, with a tiny minimum part size so
|
|
|
|
|
+ tests can use small payloads."""
|
|
|
|
|
+
|
|
|
|
|
+ CB_MULTIPART_MIN_PART_SIZE = 1
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, provider, part_size, concurrency):
|
|
|
|
|
+ super(_DriverObject, self).__init__(provider)
|
|
|
|
|
+ self._part_size = part_size
|
|
|
|
|
+ self._concurrency = concurrency
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def id(self):
|
|
|
|
|
+ return "obj"
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def name(self):
|
|
|
|
|
+ return "obj"
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def bucket(self):
|
|
|
|
|
+ return "BUCKET"
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def _multipart_part_size(self):
|
|
|
|
|
+ return self._part_size
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def _multipart_max_concurrency(self):
|
|
|
|
|
+ return self._concurrency
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class MultipartDriverTestCase(unittest.TestCase):
|
|
|
|
|
+
|
|
|
|
|
+ def _driver(self, recorder, part_size, concurrency):
|
|
|
|
|
+ return _DriverObject(_FakeProvider(recorder), part_size, concurrency)
|
|
|
|
|
+
|
|
|
|
|
+ def test_reassembles_payload_in_order(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ driver = self._driver(recorder, part_size=4, concurrency=3)
|
|
|
|
|
+ content = b"abcdefghijABCDEFGHIJ0123456789x" # 31 bytes -> 8 parts
|
|
|
|
|
+ result = driver._upload_multipart(BytesIO(content))
|
|
|
|
|
+ self.assertEqual(result, content)
|
|
|
|
|
+ self.assertEqual(recorder.completed_order, list(range(1, 9)))
|
|
|
|
|
+ # Final part is the short remainder (3 bytes).
|
|
|
|
|
+ self.assertEqual(recorder.parts[8], content[28:])
|
|
|
|
|
+
|
|
|
|
|
+ def test_handles_short_reads_without_undersized_parts(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ driver = self._driver(recorder, part_size=8, concurrency=2)
|
|
|
|
|
+
|
|
|
|
|
+ class _DripStream:
|
|
|
|
|
+ """Returns at most 3 bytes per read to simulate a socket-like
|
|
|
|
|
+ stream; the driver must coalesce reads up to the part size."""
|
|
|
|
|
+ def __init__(self, data):
|
|
|
|
|
+ self._buf = BytesIO(data)
|
|
|
|
|
+
|
|
|
|
|
+ def read(self, size):
|
|
|
|
|
+ return self._buf.read(min(size, 3))
|
|
|
|
|
+
|
|
|
|
|
+ content = bytes(range(20)) # 20 bytes, part_size 8 -> 8,8,4
|
|
|
|
|
+ result = driver._upload_multipart(_DripStream(content))
|
|
|
|
|
+ self.assertEqual(result, content)
|
|
|
|
|
+ self.assertEqual(len(recorder.parts[1]), 8)
|
|
|
|
|
+ self.assertEqual(len(recorder.parts[2]), 8)
|
|
|
|
|
+ self.assertEqual(len(recorder.parts[3]), 4)
|
|
|
|
|
+
|
|
|
|
|
+ def test_uploads_parts_concurrently_via_cloned_services(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ concurrency = 4
|
|
|
|
|
+ driver = self._driver(recorder, part_size=1, concurrency=concurrency)
|
|
|
|
|
+ # 12 parts (one byte each) across a pool of 4 clones.
|
|
|
|
|
+ content = b"0123456789ab"
|
|
|
|
|
+ driver._upload_multipart(BytesIO(content))
|
|
|
|
|
+
|
|
|
|
|
+ # A clone per worker, reused across parts.
|
|
|
|
|
+ self.assertEqual(recorder.clone_count, concurrency)
|
|
|
|
|
+ self.assertEqual(len(recorder.services_used), concurrency)
|
|
|
|
|
+ # Real parallelism happened, bounded by the configured concurrency.
|
|
|
|
|
+ self.assertGreater(recorder.max_active, 1)
|
|
|
|
|
+ self.assertLessEqual(recorder.max_active, concurrency)
|
|
|
|
|
+
|
|
|
|
|
+ def test_single_concurrency_does_not_clone(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ driver = self._driver(recorder, part_size=4, concurrency=1)
|
|
|
|
|
+ driver._upload_multipart(BytesIO(b"abcdefghij"))
|
|
|
|
|
+ self.assertEqual(recorder.clone_count, 0)
|
|
|
|
|
+ self.assertEqual(recorder.max_active, 1)
|
|
|
|
|
+
|
|
|
|
|
+ def test_aborts_and_raises_on_part_failure(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ recorder.fail_on_part = 2
|
|
|
|
|
+ driver = self._driver(recorder, part_size=4, concurrency=3)
|
|
|
|
|
+ with self.assertRaises(Exception):
|
|
|
|
|
+ driver._upload_multipart(BytesIO(b"abcdefghijklmnop"))
|
|
|
|
|
+ self.assertTrue(recorder.aborted)
|
|
|
|
|
+
|
|
|
|
|
+ def test_part_size_below_minimum_raises(self):
|
|
|
|
|
+ recorder = _Recorder()
|
|
|
|
|
+ # Use the real minimum here (not the test override) by going through a
|
|
|
|
|
+ # driver whose part size is one below the production floor.
|
|
|
|
|
+ driver = _DriverObject(_FakeProvider(recorder), part_size=4,
|
|
|
|
|
+ concurrency=2)
|
|
|
|
|
+ driver.CB_MULTIPART_MIN_PART_SIZE = 5
|
|
|
|
|
+ with self.assertRaises(InvalidValueException):
|
|
|
|
|
+ driver._upload_multipart(BytesIO(b"abc"))
|
|
|
|
|
+ # Nothing should have been created before validation failed.
|
|
|
|
|
+ self.assertEqual(recorder.completed_order, None)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ unittest.main()
|