2
0
Эх сурвалжийг харах

Rationalize base-layer casts from review

Address review feedback on the "pattern of unnecessary casts" in base/,
keeping provider-internal implementation details off the public interface:

- get_env: precise @overloads mirroring os.environ.get, replacing Any/Any.
- region_name: widen interface + base to `str | None` (matches the
  underlying value); drop the cast. Guard the one AWS call site that
  assumed non-None with an explicit ProviderInternalException.
- Region.default_zone: declare on the interface so zone_name can read it
  directly, removing the getattr + double-cast in BaseCloudProvider.
- Pageable generics: bound `T` to CloudResource (interface + base) so
  paging code reads `.id` without casting.
- _bucket_objects / _get_config_value stay OFF the public interface (they
  are implementation details). Instead:
  * declare _bucket_objects as an abstract member on BaseStorageService;
  * make BaseCloudResource._provider a covariant override returning
    BaseCloudProvider, so base internals are visible without per-call casts;
  * unify all four _bucket_objects access sites on one meaningful downcast
    to BaseStorageService (removes the Any-hop in subservices and every
    `# type: ignore[attr-defined]`).
- InstanceService.delete / BucketService.delete: declare on the public
  interface for consistency with the eight other services that already do,
  letting Instance.delete()/Bucket.delete() drop their casts.
Nuwan Goonasekera 18 цаг өмнө
parent
commit
33ed360848

+ 12 - 1
cloudbridge/base/helpers.py

@@ -9,6 +9,7 @@ from contextlib import contextmanager
 from typing import Any
 from typing import TypeVar
 from typing import cast
+from typing import overload
 
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization as crypt_serialization
@@ -119,7 +120,17 @@ def cleanup_action(cleanup_func: Callable[[], object]) -> Iterator[None]:
         log.exception("Error during exception cleanup: ")
 
 
-def get_env(varname: str, default_value: Any = None) -> Any:
+@overload
+def get_env(varname: str) -> str | None:
+    ...
+
+
+@overload
+def get_env(varname: str, default_value: T) -> str | T:
+    ...
+
+
+def get_env(varname: str, default_value: object = None) -> object:
     """
     Return the value of the environment variable or default_value.
 

+ 3 - 7
cloudbridge/base/provider.py

@@ -16,7 +16,6 @@ from ..interfaces import CloudProvider
 from ..interfaces.exceptions import ProviderConnectionException
 from ..interfaces.resources import Configuration
 from ..interfaces.resources import PlacementZone
-from ..interfaces.resources import Region
 
 log = logging.getLogger(__name__)
 
@@ -100,17 +99,14 @@ class BaseCloudProvider(CloudProvider):
         self._zone_name: str | None = None
 
     @property
-    def region_name(self) -> str:
-        return cast(str, self._region_name)
+    def region_name(self) -> str | None:
+        return self._region_name
 
     @property
     def zone_name(self) -> str | None:
         if not self._zone_name:
             region = self.compute.regions.current
-            # ``default_zone`` is provided by the concrete Region
-            # implementation rather than the public Region interface.
-            zone = cast("PlacementZone | None",
-                        getattr(cast(Region, region), 'default_zone'))
+            zone = region.default_zone if region else None
             self._zone_name = zone.name if zone else None
             return self._zone_name
         else:

+ 27 - 35
cloudbridge/base/resources.py

@@ -72,13 +72,15 @@ from . import helpers as cb_helpers
 if TYPE_CHECKING:
     from _typeshed import SupportsRead
 
+    from cloudbridge.base.provider import BaseCloudProvider
+    from cloudbridge.base.services import BaseStorageService
     from cloudbridge.interfaces.services import BucketObjectService
 
 log = logging.getLogger(__name__)
 
 # Element type for the generic pageable collections defined in this module
 # (mirrors ``cloudbridge.interfaces.resources.T``).
-T = TypeVar("T")
+T = TypeVar("T", bound=CloudResource)
 
 
 class BaseCloudResource(CloudResource):
@@ -132,8 +134,12 @@ class BaseCloudResource(CloudResource):
         return name
 
     @property
-    def _provider(self) -> CloudProvider:
-        return self.__provider
+    def _provider(self) -> "BaseCloudProvider":
+        # Base resources are always constructed with a base provider, so expose
+        # the base type here. This makes base-layer implementation details
+        # (e.g. ``_get_config_value``) visible to subclasses without per-call
+        # casts, while ``CloudResource._provider`` keeps the public type.
+        return cast("BaseCloudProvider", self.__provider)
 
     def to_json(self) -> dict[str, Any]:
         # Get all attributes but filter methods and private/magic ones
@@ -266,7 +272,7 @@ class ClientPagedResultList(BaseResultList[T]):
         total_size = len(objects)
         if marker:
             from_marker = itertools.dropwhile(
-                lambda obj: not cast(CloudResource, obj).id == marker, objects)
+                lambda obj: not obj.id == marker, objects)
             # skip one past the marker
             next(from_marker, None)
             objects = list(from_marker)
@@ -274,7 +280,7 @@ class ClientPagedResultList(BaseResultList[T]):
         results = list(itertools.islice(objects, limit))
         super(ClientPagedResultList, self).__init__(
             is_truncated,
-            cast(CloudResource, results[-1]).id if is_truncated else None,
+            results[-1].id if is_truncated else None,
             True, total=total_size,
             data=results)
 
@@ -355,9 +361,7 @@ class BaseInstance(BaseCloudResource, BaseObjectLifeCycleMixin, Instance):
             interval=interval)
 
     def delete(self) -> None:
-        # InstanceService.delete is implemented by every provider but is not
-        # declared on the public typed interface, hence the ignore.
-        self._provider.compute.instances.delete(self)  # type: ignore[attr-defined]
+        self._provider.compute.instances.delete(self)
 
 
 class BaseLaunchConfig(LaunchConfig):
@@ -807,11 +811,10 @@ class BaseMultipartUpload(BaseCloudResource, MultipartUpload):
 
     @property
     def _bucket_objects(self) -> "BucketObjectService":
-        # _bucket_objects is a provider-internal service not exposed on the
-        # public StorageService interface, hence the typed cast + ignore.
-        return cast(
-            "BucketObjectService",
-            self._provider.storage._bucket_objects)  # type: ignore[attr-defined]
+        # ``_bucket_objects`` is a base-layer member (BaseStorageService), not
+        # part of the public StorageService interface.
+        storage = cast("BaseStorageService", self._provider.storage)
+        return storage._bucket_objects
 
 
 class BaseBucketObject(BaseCloudResource, BucketObject):
@@ -854,11 +857,10 @@ class BaseBucketObject(BaseCloudResource, BucketObject):
 
     @property
     def _bucket_objects(self) -> "BucketObjectService":
-        # _bucket_objects is a provider-internal service not exposed on the
-        # public StorageService interface, hence the typed cast + ignore.
-        return cast(
-            "BucketObjectService",
-            self._provider.storage._bucket_objects)  # type: ignore[attr-defined]
+        # ``_bucket_objects`` is a base-layer member (BaseStorageService), not
+        # part of the public StorageService interface.
+        storage = cast("BaseStorageService", self._provider.storage)
+        return storage._bucket_objects
 
     @staticmethod
     def is_valid_resource_name(name: str) -> bool:
@@ -888,25 +890,20 @@ class BaseBucketObject(BaseCloudResource, BucketObject):
     def _multipart_threshold(self, config: UploadConfig | None = None) -> int:
         if config is not None and config.threshold is not None:
             return int(config.threshold)
-        # pylint:disable=protected-access
-        # _get_config_value is a provider-internal helper not on the public
-        # CloudProvider interface, hence the ignore.
-        return int(self._provider._get_config_value(  # type: ignore[attr-defined]
+        return int(self._provider._get_config_value(
             'multipart_threshold', self.CB_MULTIPART_THRESHOLD))
 
     def _multipart_part_size(self, config: UploadConfig | None = None) -> int:
         if config is not None and config.part_size is not None:
             return int(config.part_size)
-        # pylint:disable=protected-access
-        return int(self._provider._get_config_value(  # type: ignore[attr-defined]
+        return int(self._provider._get_config_value(
             'multipart_part_size', self.CB_MULTIPART_PART_SIZE))
 
     def _multipart_max_concurrency(
             self, config: UploadConfig | None = None) -> int:
         if config is not None and config.max_concurrency is not None:
             return int(config.max_concurrency)
-        # pylint:disable=protected-access
-        return int(self._provider._get_config_value(  # type: ignore[attr-defined]
+        return int(self._provider._get_config_value(
             'multipart_max_concurrency', self.CB_MULTIPART_MAX_CONCURRENCY))
 
     @staticmethod
@@ -1005,12 +1002,9 @@ class BaseBucketObject(BaseCloudResource, BucketObject):
         # thread touches an isolated provider/connection.
         clones: "queue.Queue[BucketObjectService]" = queue.Queue()
         for _ in range(concurrency):
-            # pylint:disable=protected-access
-            # _bucket_objects is a provider-internal service not exposed on the
-            # public StorageService interface, hence the typed cast + ignore.
-            clones.put(cast(
-                "BucketObjectService",
-                self._provider.clone().storage._bucket_objects))  # type: ignore[attr-defined]
+            storage = cast("BaseStorageService",
+                           self._provider.clone().storage)
+            clones.put(storage._bucket_objects)
 
         def upload_one(part_number: int, chunk: bytes) -> UploadPart:
             service = clones.get()
@@ -1103,9 +1097,7 @@ class BaseBucket(BaseCloudResource, Bucket):
         if delete_contents:
             for obj in self.objects:
                 obj.delete()
-        # BucketService.delete is implemented by every provider but is not
-        # declared on the public typed interface, hence the ignore.
-        self._provider.storage.buckets.delete(self.id)  # type: ignore[attr-defined]
+        self._provider.storage.buckets.delete(self.id)
 
     # TODO: Discuss creating `create_object` method, or change docs
 

+ 15 - 0
cloudbridge/base/services.py

@@ -2,6 +2,7 @@
 Base implementation for services available through a provider
 """
 import logging
+from abc import abstractmethod
 from typing import Any
 from typing import cast
 
@@ -162,6 +163,20 @@ class BaseStorageService(StorageService, BaseCloudService):
     def __init__(self, provider: CloudProvider) -> None:
         super(BaseStorageService, self).__init__(provider)
 
+    @property
+    @abstractmethod
+    def _bucket_objects(self) -> BucketObjectService:
+        """
+        Provider-internal service backing bucket-object operations.
+
+        This is the service that ``bucket.objects`` (BucketObjectSubService)
+        and the base multipart-upload code delegate to. It is a base-layer
+        implementation detail, deliberately not part of the public
+        StorageService interface; every provider's storage service implements
+        it.
+        """
+        pass
+
 
 class BaseVolumeService(
         BasePageableObjectMixin[Volume], VolumeService, BaseCloudService):

+ 8 - 4
cloudbridge/base/subservices.py

@@ -1,6 +1,7 @@
 import builtins
 import logging
 from typing import Any
+from typing import TYPE_CHECKING
 from typing import cast
 
 from cloudbridge.interfaces.provider import CloudProvider
@@ -27,6 +28,9 @@ from cloudbridge.interfaces.subservices import VMFirewallRuleSubService
 
 from .resources import BasePageableObjectMixin
 
+if TYPE_CHECKING:
+    from .services import BaseStorageService
+
 log = logging.getLogger(__name__)
 
 
@@ -43,10 +47,10 @@ class BaseBucketObjectSubService(BasePageableObjectMixin[BucketObject],
 
     @property
     def _bucket_objects(self) -> BucketObjectService:
-        # ``_bucket_objects`` is a provider-internal service not declared on
-        # the StorageService interface; reach it through ``Any``.
-        storage: Any = self._provider.storage
-        return cast(BucketObjectService, storage._bucket_objects)
+        # ``_bucket_objects`` is a base-layer member (BaseStorageService), not
+        # part of the public StorageService interface.
+        storage = cast("BaseStorageService", self._provider.storage)
+        return storage._bucket_objects
 
     def get(self, name: str) -> BucketObject | None:
         return self._bucket_objects.get(self.bucket, name)

+ 1 - 1
cloudbridge/interfaces/provider.py

@@ -165,7 +165,7 @@ class CloudProvider(object):
         pass
 
     @abstractproperty
-    def region_name(self) -> str:
+    def region_name(self) -> str | None:
         """
         Returns the region that this provider is connected to.
         All provider operations will take place within this region.

+ 14 - 2
cloudbridge/interfaces/resources.py

@@ -25,8 +25,10 @@ if TYPE_CHECKING:
     from cloudbridge.interfaces.subservices import SubnetSubService
     from cloudbridge.interfaces.subservices import VMFirewallRuleSubService
 
-# Element type for pageable collections (services and ResultList).
-T = TypeVar("T")
+# Element type for pageable collections (services and ResultList). Every such
+# element is a CloudResource, so the bound lets paging code read `.id` without
+# casting.
+T = TypeVar("T", bound="CloudResource")
 
 
 class CloudServiceType(object):
@@ -1892,6 +1894,16 @@ class Region(CloudResource):
         """
         pass
 
+    @abstractproperty
+    def default_zone(self) -> PlacementZone:
+        """
+        Access the default placement zone for this region.
+
+        :rtype: :class:`.PlacementZone`
+        :return: The default placement zone for this region.
+        """
+        pass
+
 
 class PlacementZone(CloudResource):
     """

+ 20 - 0
cloudbridge/interfaces/services.py

@@ -322,6 +322,16 @@ class InstanceService(PageableObjectMixin[Instance], CloudService):
         """
         pass
 
+    @abstractmethod
+    def delete(self, instance: Instance | str) -> None:
+        """
+        Permanently delete an instance.
+
+        :type instance: :class:`.Instance` or ``str``
+        :param instance: The object or ID of the instance to be deleted.
+        """
+        pass
+
 
 class VolumeService(PageableObjectMixin[Volume], CloudService):
     """
@@ -1178,6 +1188,16 @@ class BucketService(PageableObjectMixin[Bucket], CloudService):
         """
         pass
 
+    @abstractmethod
+    def delete(self, bucket: Bucket | str) -> None:
+        """
+        Delete a bucket.
+
+        :type bucket: :class:`.Bucket` or ``str``
+        :param bucket: The object or ID of the bucket to be deleted.
+        """
+        pass
+
 
 class BucketObjectService(CloudService):
 

+ 7 - 3
cloudbridge/providers/aws/services.py

@@ -48,6 +48,7 @@ from cloudbridge.interfaces.exceptions import \
     InvalidConfigurationException
 from cloudbridge.interfaces.exceptions import InvalidParamException
 from cloudbridge.interfaces.exceptions import InvalidValueException
+from cloudbridge.interfaces.exceptions import ProviderInternalException
 from cloudbridge.interfaces.resources import Bucket
 from cloudbridge.interfaces.resources import BucketObject
 from cloudbridge.interfaces.resources import DnsRecord
@@ -1366,9 +1367,12 @@ class AWSSubnetService(BaseSubnetService):
         #     default_router = default_routers[0]
 
         # Create a subnet in each of the region's zones
-        region = cast(
-            Region, self.provider.compute.regions.get(
-                self.provider.region_name))
+        region_name = self.provider.region_name
+        if region_name is None:
+            raise ProviderInternalException(
+                "Cannot create default network resources: provider has no "
+                "region")
+        region = cast(Region, self.provider.compute.regions.get(region_name))
         default_sn = None
 
         # Determine how many subnets we'll need for the default network and the