| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- """
- Provider-agnostic unit tests for the base multipart upload driver
- (``BaseBucketObject._upload_multipart``).
- The driver is the engine behind transparent large uploads on providers that do
- not override it (GCP, OpenStack Swift). Because the mock provider is AWS-backed
- and AWS overrides the driver with boto3's native uploader, the driver is
- exercised here directly against in-memory fakes so it has coverage in CI
- without cloud credentials.
- """
- import threading
- import unittest
- from io import BytesIO
- from cloudbridge.base.resources import BaseBucketObject
- from cloudbridge.base.resources import BaseMultipartUpload
- from cloudbridge.base.resources import BaseUploadPart
- from cloudbridge.interfaces.exceptions import InvalidValueException
- class _Recorder:
- """Thread-safe sink shared by the original and all cloned fake services."""
- def __init__(self):
- self._lock = threading.Lock()
- self.parts = {} # part_number -> bytes
- self.services_used = set() # id() of each service that uploaded a part
- self.clone_count = 0
- self.completed_order = None
- self.aborted = False
- self.active = 0
- self.max_active = 0
- self.fail_on_part = None # part_number that should raise
- def record_part(self, service, part_number, data):
- with self._lock:
- self.active += 1
- self.max_active = max(self.max_active, self.active)
- try:
- if self.fail_on_part == part_number:
- raise RuntimeError("boom on part %d" % part_number)
- # Hold briefly so concurrent uploads genuinely overlap.
- time_to_sleep = 0.02
- _sleep(time_to_sleep)
- with self._lock:
- self.parts[part_number] = bytes(data)
- self.services_used.add(id(service))
- finally:
- with self._lock:
- self.active -= 1
- def _sleep(seconds):
- # Indirection so the deterministic tests can monkeypatch if needed; a plain
- # sleep is fine here and keeps the overlap window small.
- threading.Event().wait(seconds)
- class _FakeService:
- def __init__(self, recorder, provider):
- self._recorder = recorder
- self._provider = provider
- def create_multipart_upload(self, bucket, object_name):
- return BaseMultipartUpload(self._provider, bucket, object_name, "upl")
- def upload_part(self, bucket, upload, part_number, data):
- self._recorder.record_part(self, part_number, data)
- return BaseUploadPart(part_number, "etag-%d" % part_number)
- def complete_multipart_upload(self, bucket, upload, parts):
- ordered = sorted(parts, key=lambda p: p.part_number)
- self._recorder.completed_order = [p.part_number for p in ordered]
- return b"".join(self._recorder.parts[p.part_number] for p in ordered)
- def abort_multipart_upload(self, bucket, upload):
- self._recorder.aborted = True
- class _FakeStorage:
- def __init__(self, service):
- self._bucket_objects = service
- class _FakeProvider:
- def __init__(self, recorder):
- self._recorder = recorder
- self.storage = _FakeStorage(_FakeService(recorder, self))
- def clone(self, zone=None):
- self._recorder.clone_count += 1
- return _FakeProvider(self._recorder)
- def _get_config_value(self, key, default_value=None):
- return default_value
class _DriverObject(BaseBucketObject):
    """Concrete ``BaseBucketObject`` wired to the in-memory fakes.

    The minimum part size is lowered to a single byte so that tests can
    use small payloads.
    """

    CB_MULTIPART_MIN_PART_SIZE = 1

    def __init__(self, provider, part_size, concurrency):
        super(_DriverObject, self).__init__(provider)
        self._concurrency = concurrency
        self._part_size = part_size

    # -- multipart tuning, fixed per instance by the constructor ----------

    @property
    def _multipart_part_size(self):
        """Part size passed to the constructor."""
        return self._part_size

    @property
    def _multipart_max_concurrency(self):
        """Concurrency passed to the constructor."""
        return self._concurrency

    # -- constant identity values ------------------------------------------

    @property
    def bucket(self):
        """Fixed placeholder; the fakes never interpret it."""
        return "BUCKET"

    @property
    def name(self):
        """Fixed object name."""
        return "obj"

    @property
    def id(self):
        """Fixed object id (same value as ``name``)."""
        return "obj"
class MultipartDriverTestCase(unittest.TestCase):
    """Exercise ``_upload_multipart`` against the in-memory fakes."""

    def _driver(self, recorder, part_size, concurrency):
        # Build a driver object on a fresh fake provider that reports into
        # ``recorder``.
        return _DriverObject(_FakeProvider(recorder), part_size, concurrency)

    def test_reassembles_payload_in_order(self):
        recorder = _Recorder()
        driver = self._driver(recorder, part_size=4, concurrency=3)
        content = b"abcdefghijABCDEFGHIJ0123456789x"  # 31 bytes -> 8 parts

        result = driver._upload_multipart(BytesIO(content))

        self.assertEqual(result, content)
        self.assertEqual(recorder.completed_order, list(range(1, 9)))
        # Final part is the short remainder (3 bytes).
        self.assertEqual(recorder.parts[8], content[28:])

    def test_handles_short_reads_without_undersized_parts(self):
        recorder = _Recorder()
        driver = self._driver(recorder, part_size=8, concurrency=2)

        class _DripStream:
            """Returns at most 3 bytes per read to simulate a socket-like
            stream; the driver must coalesce reads up to the part size."""

            def __init__(self, data):
                self._buf = BytesIO(data)

            def read(self, size):
                return self._buf.read(min(size, 3))

        content = bytes(range(20))  # 20 bytes, part_size 8 -> 8,8,4

        result = driver._upload_multipart(_DripStream(content))

        self.assertEqual(result, content)
        self.assertEqual(len(recorder.parts[1]), 8)
        self.assertEqual(len(recorder.parts[2]), 8)
        self.assertEqual(len(recorder.parts[3]), 4)

    def test_uploads_parts_concurrently_via_cloned_services(self):
        recorder = _Recorder()
        concurrency = 4
        driver = self._driver(recorder, part_size=1, concurrency=concurrency)
        # 12 parts (one byte each) across a pool of 4 clones.
        content = b"0123456789ab"

        driver._upload_multipart(BytesIO(content))

        # A clone per worker, reused across parts.
        self.assertEqual(recorder.clone_count, concurrency)
        self.assertEqual(len(recorder.services_used), concurrency)
        # Real parallelism happened, bounded by the configured concurrency.
        self.assertGreater(recorder.max_active, 1)
        self.assertLessEqual(recorder.max_active, concurrency)

    def test_single_concurrency_does_not_clone(self):
        recorder = _Recorder()
        driver = self._driver(recorder, part_size=4, concurrency=1)

        driver._upload_multipart(BytesIO(b"abcdefghij"))

        self.assertEqual(recorder.clone_count, 0)
        self.assertEqual(recorder.max_active, 1)

    def test_aborts_and_raises_on_part_failure(self):
        recorder = _Recorder()
        recorder.fail_on_part = 2
        driver = self._driver(recorder, part_size=4, concurrency=3)

        with self.assertRaises(Exception):
            driver._upload_multipart(BytesIO(b"abcdefghijklmnop"))

        self.assertTrue(recorder.aborted)

    def test_part_size_below_minimum_raises(self):
        recorder = _Recorder()
        driver = self._driver(recorder, part_size=4, concurrency=2)
        # The class-level override (1) would accept any part size, so raise
        # the floor on this instance to one above the requested part size.
        driver.CB_MULTIPART_MIN_PART_SIZE = 5

        with self.assertRaises(InvalidValueException):
            driver._upload_multipart(BytesIO(b"abc"))

        # The upload must not have been completed once validation failed.
        self.assertIsNone(recorder.completed_order)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|